
<!DOCTYPE html>

<html lang="zh">
  <head>
    <meta charset="utf-8" />
    <meta name="viewport" content="width=device-width, initial-scale=1.0" /><meta name="generator" content="Docutils 0.17.1: http://docutils.sourceforge.net/" />

    <title>Transformer Explained &#8212; 深入浅出PyTorch</title>
    
  <!-- Loaded before other Sphinx assets -->
  <link href="../_static/styles/theme.css?digest=1999514e3f237ded88cf" rel="stylesheet">
<link href="../_static/styles/pydata-sphinx-theme.css?digest=1999514e3f237ded88cf" rel="stylesheet">

    
  <link rel="stylesheet"
    href="../_static/vendor/fontawesome/5.13.0/css/all.min.css">
  <link rel="preload" as="font" type="font/woff2" crossorigin
    href="../_static/vendor/fontawesome/5.13.0/webfonts/fa-solid-900.woff2">
  <link rel="preload" as="font" type="font/woff2" crossorigin
    href="../_static/vendor/fontawesome/5.13.0/webfonts/fa-brands-400.woff2">

    <link rel="stylesheet" type="text/css" href="../_static/pygments.css" />
    <link rel="stylesheet" href="../_static/styles/sphinx-book-theme.css?digest=62ba249389abaaa9ffc34bf36a076bdc1d65ee18" type="text/css" />
    <link rel="stylesheet" type="text/css" href="../_static/togglebutton.css" />
    <link rel="stylesheet" type="text/css" href="../_static/mystnb.css" />
    <link rel="stylesheet" type="text/css" href="../_static/plot_directive.css" />
    
  <!-- Pre-loaded scripts that we'll load fully later -->
  <link rel="preload" as="script" href="../_static/scripts/pydata-sphinx-theme.js?digest=1999514e3f237ded88cf">

    <script data-url_root="../" id="documentation_options" src="../_static/documentation_options.js"></script>
    <script src="../_static/jquery.js"></script>
    <script src="../_static/underscore.js"></script>
    <script src="../_static/doctools.js"></script>
    <script>let toggleHintShow = 'Click to show';</script>
    <script>let toggleHintHide = 'Click to hide';</script>
    <script>let toggleOpenOnPrint = 'true';</script>
    <script src="../_static/togglebutton.js"></script>
    <script src="../_static/scripts/sphinx-book-theme.js?digest=f31d14ad54b65d19161ba51d4ffff3a77ae00456"></script>
    <script>var togglebuttonSelector = '.toggle, .admonition.dropdown, .tag_hide_input div.cell_input, .tag_hide-input div.cell_input, .tag_hide_output div.cell_output, .tag_hide-output div.cell_output, .tag_hide_cell.cell, .tag_hide-cell.cell';</script>
    <script>window.MathJax = {"options": {"processHtmlClass": "tex2jax_process|mathjax_process|math|output_area"}}</script>
    <script defer="defer" src="https://cdn.jsdelivr.net/npm/mathjax@3/es5/tex-mml-chtml.js"></script>
    <link rel="index" title="索引" href="../genindex.html" />
    <link rel="search" title="搜索" href="../search.html" />
    <link rel="next" title="ViT解读" href="ViT%E8%A7%A3%E8%AF%BB.html" />
    <link rel="prev" title="文章结构" href="LSTM%E8%A7%A3%E8%AF%BB%E5%8F%8A%E5%AE%9E%E6%88%98.html" />
    <meta name="viewport" content="width=device-width, initial-scale=1" />
    <meta name="docsearch:language" content="zh">
    

    <!-- Google Analytics -->
    
  </head>
  <body data-spy="scroll" data-target="#bd-toc-nav" data-offset="60">
<!-- Checkboxes to toggle the left sidebar -->
<input type="checkbox" class="sidebar-toggle" name="__navigation" id="__navigation" aria-label="Toggle navigation sidebar">
<label class="overlay overlay-navbar" for="__navigation">
    <div class="visually-hidden">Toggle navigation sidebar</div>
</label>
<!-- Checkboxes to toggle the in-page toc -->
<input type="checkbox" class="sidebar-toggle" name="__page-toc" id="__page-toc" aria-label="Toggle in-page Table of Contents">
<label class="overlay overlay-pagetoc" for="__page-toc">
    <div class="visually-hidden">Toggle in-page Table of Contents</div>
</label>
<!-- Headers at the top -->
<div class="announcement header-item noprint"></div>
<div class="header header-item noprint"></div>

    
    <div class="container-fluid" id="banner"></div>

    

    <div class="container-xl">
      <div class="row">
          
<!-- Sidebar -->
<div class="bd-sidebar noprint" id="site-navigation">
    <div class="bd-sidebar__content">
        <div class="bd-sidebar__top"><div class="navbar-brand-box">
    <a class="navbar-brand text-wrap" href="../index.html">
      
      
      
      <h1 class="site-logo" id="site-title">深入浅出PyTorch</h1>
      
    </a>
</div><form class="bd-search d-flex align-items-center" action="../search.html" method="get">
  <i class="icon fas fa-search"></i>
  <input type="search" class="form-control" name="q" id="search-input" placeholder="Search the docs ..." aria-label="Search the docs ..." autocomplete="off" >
</form><nav class="bd-links" id="bd-docs-nav" aria-label="Main">
    <div class="bd-toc-item active">
        <p aria-level="2" class="caption" role="heading">
 <span class="caption-text">
  目录
 </span>
</p>
<ul class="current nav bd-sidenav">
 <li class="toctree-l1 has-children">
  <a class="reference internal" href="../%E7%AC%AC%E9%9B%B6%E7%AB%A0/index.html">
   第零章：前置知识
  </a>
  <input class="toctree-checkbox" id="toctree-checkbox-1" name="toctree-checkbox-1" type="checkbox"/>
  <label for="toctree-checkbox-1">
   <i class="fas fa-chevron-down">
   </i>
  </label>
  <ul>
   <li class="toctree-l2">
    <a class="reference internal" href="../%E7%AC%AC%E9%9B%B6%E7%AB%A0/0.1%20%E4%BA%BA%E5%B7%A5%E6%99%BA%E8%83%BD%E7%AE%80%E5%8F%B2.html">
     人工智能简史
    </a>
   </li>
   <li class="toctree-l2">
    <a class="reference internal" href="../%E7%AC%AC%E9%9B%B6%E7%AB%A0/0.2%20%E8%AF%84%E4%BB%B7%E6%8C%87%E6%A0%87.html">
     模型评价指标
    </a>
   </li>
   <li class="toctree-l2">
    <a class="reference internal" href="../%E7%AC%AC%E9%9B%B6%E7%AB%A0/0.3%20%E5%B8%B8%E7%94%A8%E5%8C%85%E7%9A%84%E5%AD%A6%E4%B9%A0.html">
     常用包的学习
    </a>
   </li>
   <li class="toctree-l2">
    <a class="reference internal" href="../%E7%AC%AC%E9%9B%B6%E7%AB%A0/0.4%20Jupyter%E7%9B%B8%E5%85%B3%E6%93%8D%E4%BD%9C.html">
     Jupyter notebook/Lab 简述
    </a>
   </li>
  </ul>
 </li>
 <li class="toctree-l1 has-children">
  <a class="reference internal" href="../%E7%AC%AC%E4%B8%80%E7%AB%A0/index.html">
   第一章：PyTorch的简介和安装
  </a>
  <input class="toctree-checkbox" id="toctree-checkbox-2" name="toctree-checkbox-2" type="checkbox"/>
  <label for="toctree-checkbox-2">
   <i class="fas fa-chevron-down">
   </i>
  </label>
  <ul>
   <li class="toctree-l2">
    <a class="reference internal" href="../%E7%AC%AC%E4%B8%80%E7%AB%A0/1.1%20PyTorch%E7%AE%80%E4%BB%8B.html">
     1.1 PyTorch简介
    </a>
   </li>
   <li class="toctree-l2">
    <a class="reference internal" href="../%E7%AC%AC%E4%B8%80%E7%AB%A0/1.2%20PyTorch%E7%9A%84%E5%AE%89%E8%A3%85.html">
     1.2 PyTorch的安装
    </a>
   </li>
   <li class="toctree-l2">
    <a class="reference internal" href="../%E7%AC%AC%E4%B8%80%E7%AB%A0/1.3%20PyTorch%E7%9B%B8%E5%85%B3%E8%B5%84%E6%BA%90.html">
     1.3 PyTorch相关资源
    </a>
   </li>
  </ul>
 </li>
 <li class="toctree-l1 has-children">
  <a class="reference internal" href="../%E7%AC%AC%E4%BA%8C%E7%AB%A0/index.html">
   第二章：PyTorch基础知识
  </a>
  <input class="toctree-checkbox" id="toctree-checkbox-3" name="toctree-checkbox-3" type="checkbox"/>
  <label for="toctree-checkbox-3">
   <i class="fas fa-chevron-down">
   </i>
  </label>
  <ul>
   <li class="toctree-l2">
    <a class="reference internal" href="../%E7%AC%AC%E4%BA%8C%E7%AB%A0/2.1%20%E5%BC%A0%E9%87%8F.html">
     2.1 张量
    </a>
   </li>
   <li class="toctree-l2">
    <a class="reference internal" href="../%E7%AC%AC%E4%BA%8C%E7%AB%A0/2.2%20%E8%87%AA%E5%8A%A8%E6%B1%82%E5%AF%BC.html">
     2.2 自动求导
    </a>
   </li>
   <li class="toctree-l2">
    <a class="reference internal" href="../%E7%AC%AC%E4%BA%8C%E7%AB%A0/2.3%20%E5%B9%B6%E8%A1%8C%E8%AE%A1%E7%AE%97%E7%AE%80%E4%BB%8B.html">
     2.3 并行计算简介
    </a>
   </li>
   <li class="toctree-l2">
    <a class="reference internal" href="../%E7%AC%AC%E4%BA%8C%E7%AB%A0/2.4%20AI%E7%A1%AC%E4%BB%B6%E5%8A%A0%E9%80%9F%E8%AE%BE%E5%A4%87.html">
     AI硬件加速设备
    </a>
   </li>
  </ul>
 </li>
 <li class="toctree-l1 has-children">
  <a class="reference internal" href="../%E7%AC%AC%E4%B8%89%E7%AB%A0/index.html">
   第三章：PyTorch的主要组成模块
  </a>
  <input class="toctree-checkbox" id="toctree-checkbox-4" name="toctree-checkbox-4" type="checkbox"/>
  <label for="toctree-checkbox-4">
   <i class="fas fa-chevron-down">
   </i>
  </label>
  <ul>
   <li class="toctree-l2">
    <a class="reference internal" href="../%E7%AC%AC%E4%B8%89%E7%AB%A0/3.1%20%E6%80%9D%E8%80%83%EF%BC%9A%E5%AE%8C%E6%88%90%E6%B7%B1%E5%BA%A6%E5%AD%A6%E4%B9%A0%E7%9A%84%E5%BF%85%E8%A6%81%E9%83%A8%E5%88%86.html">
     3.1 思考：完成深度学习的必要部分
    </a>
   </li>
   <li class="toctree-l2">
    <a class="reference internal" href="../%E7%AC%AC%E4%B8%89%E7%AB%A0/3.2%20%E5%9F%BA%E6%9C%AC%E9%85%8D%E7%BD%AE.html">
     3.2 基本配置
    </a>
   </li>
   <li class="toctree-l2">
    <a class="reference internal" href="../%E7%AC%AC%E4%B8%89%E7%AB%A0/3.3%20%E6%95%B0%E6%8D%AE%E8%AF%BB%E5%85%A5.html">
     3.3 数据读入
    </a>
   </li>
   <li class="toctree-l2">
    <a class="reference internal" href="../%E7%AC%AC%E4%B8%89%E7%AB%A0/3.4%20%E6%A8%A1%E5%9E%8B%E6%9E%84%E5%BB%BA.html">
     3.4 模型构建
    </a>
   </li>
   <li class="toctree-l2">
    <a class="reference internal" href="../%E7%AC%AC%E4%B8%89%E7%AB%A0/3.5%20%E6%A8%A1%E5%9E%8B%E5%88%9D%E5%A7%8B%E5%8C%96.html">
     3.5 模型初始化
    </a>
   </li>
   <li class="toctree-l2">
    <a class="reference internal" href="../%E7%AC%AC%E4%B8%89%E7%AB%A0/3.6%20%E6%8D%9F%E5%A4%B1%E5%87%BD%E6%95%B0.html">
     3.6 损失函数
    </a>
   </li>
   <li class="toctree-l2">
    <a class="reference internal" href="../%E7%AC%AC%E4%B8%89%E7%AB%A0/3.7%20%E8%AE%AD%E7%BB%83%E4%B8%8E%E8%AF%84%E4%BC%B0.html">
     3.7 训练和评估
    </a>
   </li>
   <li class="toctree-l2">
    <a class="reference internal" href="../%E7%AC%AC%E4%B8%89%E7%AB%A0/3.8%20%E5%8F%AF%E8%A7%86%E5%8C%96.html">
     3.8 可视化
    </a>
   </li>
   <li class="toctree-l2">
    <a class="reference internal" href="../%E7%AC%AC%E4%B8%89%E7%AB%A0/3.9%20%E4%BC%98%E5%8C%96%E5%99%A8.html">
     3.9 PyTorch优化器
    </a>
   </li>
  </ul>
 </li>
 <li class="toctree-l1 has-children">
  <a class="reference internal" href="../%E7%AC%AC%E5%9B%9B%E7%AB%A0/index.html">
   第四章：PyTorch基础实战
  </a>
  <input class="toctree-checkbox" id="toctree-checkbox-5" name="toctree-checkbox-5" type="checkbox"/>
  <label for="toctree-checkbox-5">
   <i class="fas fa-chevron-down">
   </i>
  </label>
  <ul>
   <li class="toctree-l2">
    <a class="reference internal" href="../%E7%AC%AC%E5%9B%9B%E7%AB%A0/4.1%20ResNet.html">
     4.1 ResNet
    </a>
   </li>
   <li class="toctree-l2">
    <a class="reference internal" href="../%E7%AC%AC%E5%9B%9B%E7%AB%A0/4.4%20FashionMNIST%E5%9B%BE%E5%83%8F%E5%88%86%E7%B1%BB.html">
     基础实战——FashionMNIST时装分类
    </a>
   </li>
  </ul>
 </li>
 <li class="toctree-l1 has-children">
  <a class="reference internal" href="../%E7%AC%AC%E4%BA%94%E7%AB%A0/index.html">
   第五章：PyTorch模型定义
  </a>
  <input class="toctree-checkbox" id="toctree-checkbox-6" name="toctree-checkbox-6" type="checkbox"/>
  <label for="toctree-checkbox-6">
   <i class="fas fa-chevron-down">
   </i>
  </label>
  <ul>
   <li class="toctree-l2">
    <a class="reference internal" href="../%E7%AC%AC%E4%BA%94%E7%AB%A0/5.1%20PyTorch%E6%A8%A1%E5%9E%8B%E5%AE%9A%E4%B9%89%E7%9A%84%E6%96%B9%E5%BC%8F.html">
     5.1 PyTorch模型定义的方式
    </a>
   </li>
   <li class="toctree-l2">
    <a class="reference internal" href="../%E7%AC%AC%E4%BA%94%E7%AB%A0/5.2%20%E5%88%A9%E7%94%A8%E6%A8%A1%E5%9E%8B%E5%9D%97%E5%BF%AB%E9%80%9F%E6%90%AD%E5%BB%BA%E5%A4%8D%E6%9D%82%E7%BD%91%E7%BB%9C.html">
     5.2 利用模型块快速搭建复杂网络
    </a>
   </li>
   <li class="toctree-l2">
    <a class="reference internal" href="../%E7%AC%AC%E4%BA%94%E7%AB%A0/5.3%20PyTorch%E4%BF%AE%E6%94%B9%E6%A8%A1%E5%9E%8B.html">
     5.3 PyTorch修改模型
    </a>
   </li>
   <li class="toctree-l2">
    <a class="reference internal" href="../%E7%AC%AC%E4%BA%94%E7%AB%A0/5.4%20PyTorh%E6%A8%A1%E5%9E%8B%E4%BF%9D%E5%AD%98%E4%B8%8E%E8%AF%BB%E5%8F%96.html">
     5.4 PyTorch模型保存与读取
    </a>
   </li>
  </ul>
 </li>
 <li class="toctree-l1 has-children">
  <a class="reference internal" href="../%E7%AC%AC%E5%85%AD%E7%AB%A0/index.html">
   第六章：PyTorch进阶训练技巧
  </a>
  <input class="toctree-checkbox" id="toctree-checkbox-7" name="toctree-checkbox-7" type="checkbox"/>
  <label for="toctree-checkbox-7">
   <i class="fas fa-chevron-down">
   </i>
  </label>
  <ul>
   <li class="toctree-l2">
    <a class="reference internal" href="../%E7%AC%AC%E5%85%AD%E7%AB%A0/6.1%20%E8%87%AA%E5%AE%9A%E4%B9%89%E6%8D%9F%E5%A4%B1%E5%87%BD%E6%95%B0.html">
     6.1 自定义损失函数
    </a>
   </li>
   <li class="toctree-l2">
    <a class="reference internal" href="../%E7%AC%AC%E5%85%AD%E7%AB%A0/6.2%20%E5%8A%A8%E6%80%81%E8%B0%83%E6%95%B4%E5%AD%A6%E4%B9%A0%E7%8E%87.html">
     6.2 动态调整学习率
    </a>
   </li>
   <li class="toctree-l2">
    <a class="reference internal" href="../%E7%AC%AC%E5%85%AD%E7%AB%A0/6.3%20%E6%A8%A1%E5%9E%8B%E5%BE%AE%E8%B0%83-torchvision.html">
     6.3 模型微调-torchvision
    </a>
   </li>
   <li class="toctree-l2">
    <a class="reference internal" href="../%E7%AC%AC%E5%85%AD%E7%AB%A0/6.3%20%E6%A8%A1%E5%9E%8B%E5%BE%AE%E8%B0%83-timm.html">
     6.3 模型微调 - timm
    </a>
   </li>
   <li class="toctree-l2">
    <a class="reference internal" href="../%E7%AC%AC%E5%85%AD%E7%AB%A0/6.4%20%E5%8D%8A%E7%B2%BE%E5%BA%A6%E8%AE%AD%E7%BB%83.html">
     6.4 半精度训练
    </a>
   </li>
   <li class="toctree-l2">
    <a class="reference internal" href="../%E7%AC%AC%E5%85%AD%E7%AB%A0/6.5%20%E6%95%B0%E6%8D%AE%E5%A2%9E%E5%BC%BA-imgaug.html">
     6.5 数据增强-imgaug
    </a>
   </li>
   <li class="toctree-l2">
    <a class="reference internal" href="../%E7%AC%AC%E5%85%AD%E7%AB%A0/6.6%20%E4%BD%BF%E7%94%A8argparse%E8%BF%9B%E8%A1%8C%E8%B0%83%E5%8F%82.html">
     6.6 使用argparse进行调参
    </a>
   </li>
  </ul>
 </li>
 <li class="toctree-l1 has-children">
  <a class="reference internal" href="../%E7%AC%AC%E4%B8%83%E7%AB%A0/index.html">
   第七章：PyTorch可视化
  </a>
  <input class="toctree-checkbox" id="toctree-checkbox-8" name="toctree-checkbox-8" type="checkbox"/>
  <label for="toctree-checkbox-8">
   <i class="fas fa-chevron-down">
   </i>
  </label>
  <ul>
   <li class="toctree-l2">
    <a class="reference internal" href="../%E7%AC%AC%E4%B8%83%E7%AB%A0/7.1%20%E5%8F%AF%E8%A7%86%E5%8C%96%E7%BD%91%E7%BB%9C%E7%BB%93%E6%9E%84.html">
     7.1 可视化网络结构
    </a>
   </li>
   <li class="toctree-l2">
    <a class="reference internal" href="../%E7%AC%AC%E4%B8%83%E7%AB%A0/7.2%20CNN%E5%8D%B7%E7%A7%AF%E5%B1%82%E5%8F%AF%E8%A7%86%E5%8C%96.html">
     7.2 CNN可视化
    </a>
   </li>
   <li class="toctree-l2">
    <a class="reference internal" href="../%E7%AC%AC%E4%B8%83%E7%AB%A0/7.3%20%E4%BD%BF%E7%94%A8TensorBoard%E5%8F%AF%E8%A7%86%E5%8C%96%E8%AE%AD%E7%BB%83%E8%BF%87%E7%A8%8B.html">
     7.3 使用TensorBoard可视化训练过程
    </a>
   </li>
   <li class="toctree-l2">
    <a class="reference internal" href="../%E7%AC%AC%E4%B8%83%E7%AB%A0/7.4%20%E4%BD%BF%E7%94%A8wandb%E5%8F%AF%E8%A7%86%E5%8C%96%E8%AE%AD%E7%BB%83%E8%BF%87%E7%A8%8B.html">
     7.4 使用wandb可视化训练过程
    </a>
   </li>
  </ul>
 </li>
 <li class="toctree-l1 has-children">
  <a class="reference internal" href="../%E7%AC%AC%E5%85%AB%E7%AB%A0/index.html">
   第八章：PyTorch生态简介
  </a>
  <input class="toctree-checkbox" id="toctree-checkbox-9" name="toctree-checkbox-9" type="checkbox"/>
  <label for="toctree-checkbox-9">
   <i class="fas fa-chevron-down">
   </i>
  </label>
  <ul>
   <li class="toctree-l2">
    <a class="reference internal" href="../%E7%AC%AC%E5%85%AB%E7%AB%A0/8.1%20%E6%9C%AC%E7%AB%A0%E7%AE%80%E4%BB%8B.html">
     8.1 本章简介
    </a>
   </li>
   <li class="toctree-l2">
    <a class="reference internal" href="../%E7%AC%AC%E5%85%AB%E7%AB%A0/8.2%20%E5%9B%BE%E5%83%8F%20-%20torchvision.html">
     8.2 torchvision
    </a>
   </li>
   <li class="toctree-l2">
    <a class="reference internal" href="../%E7%AC%AC%E5%85%AB%E7%AB%A0/8.3%20%E8%A7%86%E9%A2%91%20-%20PyTorchVideo.html">
     8.3 PyTorchVideo简介
    </a>
   </li>
   <li class="toctree-l2">
    <a class="reference internal" href="../%E7%AC%AC%E5%85%AB%E7%AB%A0/8.4%20%E6%96%87%E6%9C%AC%20-%20torchtext.html">
     8.4 torchtext简介
    </a>
   </li>
   <li class="toctree-l2">
    <a class="reference internal" href="../%E7%AC%AC%E5%85%AB%E7%AB%A0/8.5%20%E9%9F%B3%E9%A2%91%20-%20torchaudio.html">
     8.5 torchaudio简介
    </a>
   </li>
  </ul>
 </li>
 <li class="toctree-l1 has-children">
  <a class="reference internal" href="../%E7%AC%AC%E4%B9%9D%E7%AB%A0/index.html">
   第九章：PyTorch的模型部署
  </a>
  <input class="toctree-checkbox" id="toctree-checkbox-10" name="toctree-checkbox-10" type="checkbox"/>
  <label for="toctree-checkbox-10">
   <i class="fas fa-chevron-down">
   </i>
  </label>
  <ul>
   <li class="toctree-l2">
    <a class="reference internal" href="../%E7%AC%AC%E4%B9%9D%E7%AB%A0/9.1%20%E4%BD%BF%E7%94%A8ONNX%E8%BF%9B%E8%A1%8C%E9%83%A8%E7%BD%B2%E5%B9%B6%E6%8E%A8%E7%90%86.html">
     9.1 使用ONNX进行部署并推理
    </a>
   </li>
  </ul>
 </li>
 <li class="toctree-l1 current active has-children">
  <a class="reference internal" href="index.html">
   第十章：常见代码解读
  </a>
  <input checked="" class="toctree-checkbox" id="toctree-checkbox-11" name="toctree-checkbox-11" type="checkbox"/>
  <label for="toctree-checkbox-11">
   <i class="fas fa-chevron-down">
   </i>
  </label>
  <ul class="current">
   <li class="toctree-l2">
    <a class="reference internal" href="10.1%20%E5%9B%BE%E5%83%8F%E5%88%86%E7%B1%BB.html">
     10.1 图像分类简介（补充中）
    </a>
   </li>
   <li class="toctree-l2">
    <a class="reference internal" href="10.2%20%E7%9B%AE%E6%A0%87%E6%A3%80%E6%B5%8B.html">
     目标检测简介
    </a>
   </li>
   <li class="toctree-l2">
    <a class="reference internal" href="10.3%20%E5%9B%BE%E5%83%8F%E5%88%86%E5%89%B2.html">
     10.3 图像分割简介（补充中）
    </a>
   </li>
   <li class="toctree-l2">
    <a class="reference internal" href="ResNet%E6%BA%90%E7%A0%81%E8%A7%A3%E8%AF%BB.html">
     ResNet源码解读
    </a>
   </li>
   <li class="toctree-l2">
    <a class="reference internal" href="RNN%E8%AF%A6%E8%A7%A3%E5%8F%8A%E5%85%B6%E5%AE%9E%E7%8E%B0.html">
     文章结构
    </a>
   </li>
   <li class="toctree-l2">
    <a class="reference internal" href="LSTM%E8%A7%A3%E8%AF%BB%E5%8F%8A%E5%AE%9E%E6%88%98.html">
     文章结构
    </a>
   </li>
   <li class="toctree-l2 current active">
    <a class="current reference internal" href="#">
     Transformer 解读
    </a>
   </li>
   <li class="toctree-l2">
    <a class="reference internal" href="ViT%E8%A7%A3%E8%AF%BB.html">
     ViT解读
    </a>
   </li>
   <li class="toctree-l2">
    <a class="reference internal" href="Swin-Transformer%E8%A7%A3%E8%AF%BB.html">
     Swin Transformer解读
    </a>
   </li>
  </ul>
 </li>
</ul>

    </div>
</nav></div>
        <div class="bd-sidebar__bottom">
             <!-- To handle the deprecated key -->
            
            <div class="navbar_extra_footer">
            Theme by the <a href="https://ebp.jupyterbook.org">Executable Book Project</a>
            </div>
            
        </div>
    </div>
    <div id="rtd-footer-container"></div>
</div>


          


          
<!-- A tiny helper pixel to detect if we've scrolled -->
<div class="sbt-scroll-pixel-helper"></div>
<!-- Main content -->
<div class="col py-0 content-container">
    
    <div class="header-article row sticky-top noprint">
        



<div class="col py-1 d-flex header-article-main">
    <div class="header-article__left">
        
        <label for="__navigation"
  class="headerbtn"
  data-toggle="tooltip"
data-placement="right"
title="Toggle navigation"
>
  

<span class="headerbtn__icon-container">
  <i class="fas fa-bars"></i>
  </span>

</label>

        
    </div>
    <div class="header-article__right">
<button onclick="toggleFullScreen()"
  class="headerbtn"
  data-toggle="tooltip"
data-placement="bottom"
title="Fullscreen mode"
>
  

<span class="headerbtn__icon-container">
  <i class="fas fa-expand"></i>
  </span>

</button>

<div class="menu-dropdown menu-dropdown-repository-buttons">
  <button class="headerbtn menu-dropdown__trigger"
      aria-label="Source repositories">
      <i class="fab fa-github"></i>
  </button>
  <div class="menu-dropdown__content">
    <ul>
      <li>
        <a href="https://github.com/datawhalechina/thorough-pytorch"
   class="headerbtn"
   data-toggle="tooltip"
data-placement="left"
title="Source repository"
>
  

<span class="headerbtn__icon-container">
  <i class="fab fa-github"></i>
  </span>
<span class="headerbtn__text-container">repository</span>
</a>

      </li>
      
      <li>
        <a href="https://github.com/datawhalechina/thorough-pytorch/issues/new?title=Issue%20on%20page%20%2F第十章/Transformer 解读.html&body=Your%20issue%20content%20here."
   class="headerbtn"
   data-toggle="tooltip"
data-placement="left"
title="Open an issue"
>
  

<span class="headerbtn__icon-container">
  <i class="fas fa-lightbulb"></i>
  </span>
<span class="headerbtn__text-container">open issue</span>
</a>

      </li>
      
      <li>
        <a href="https://github.com/datawhalechina/thorough-pytorch/edit/master/第十章/Transformer 解读.md"
   class="headerbtn"
   data-toggle="tooltip"
data-placement="left"
title="Edit this page"
>
  

<span class="headerbtn__icon-container">
  <i class="fas fa-pencil-alt"></i>
  </span>
<span class="headerbtn__text-container">suggest edit</span>
</a>

      </li>
      
    </ul>
  </div>
</div>

<div class="menu-dropdown menu-dropdown-download-buttons">
  <button class="headerbtn menu-dropdown__trigger"
      aria-label="Download this page">
      <i class="fas fa-download"></i>
  </button>
  <div class="menu-dropdown__content">
    <ul>
      <li>
        <a href="../_sources/第十章/Transformer 解读.md.txt"
   class="headerbtn"
   data-toggle="tooltip"
data-placement="left"
title="Download source file"
>
  

<span class="headerbtn__icon-container">
  <i class="fas fa-file"></i>
  </span>
<span class="headerbtn__text-container">.md</span>
</a>

      </li>
      
      <li>
        
<button onclick="printPdf(this)"
  class="headerbtn"
  data-toggle="tooltip"
data-placement="left"
title="Print to PDF"
>
  

<span class="headerbtn__icon-container">
  <i class="fas fa-file-pdf"></i>
  </span>
<span class="headerbtn__text-container">.pdf</span>
</button>

      </li>
      
    </ul>
  </div>
</div>
<label for="__page-toc"
  class="headerbtn headerbtn-page-toc"
  
>
  

<span class="headerbtn__icon-container">
  <i class="fas fa-list"></i>
  </span>

</label>

    </div>
</div>

<!-- Table of contents -->
<div class="col-md-3 bd-toc show noprint">
    <div class="tocsection onthispage pt-5 pb-3">
        <i class="fas fa-list"></i> Contents
    </div>
    <nav id="bd-toc-nav" aria-label="Page">
        <ul class="visible nav section-nav flex-column">
 <li class="toc-h2 nav-item toc-entry">
  <a class="reference internal nav-link" href="#id1">
   Preface
  </a>
 </li>
 <li class="toc-h2 nav-item toc-entry">
  <a class="reference internal nav-link" href="#id2">
   Overall Architecture
  </a>
 </li>
 <li class="toc-h2 nav-item toc-entry">
  <a class="reference internal nav-link" href="#attention">
   Attention
  </a>
  <ul class="nav section-nav flex-column">
   <li class="toc-h3 nav-item toc-entry">
    <a class="reference internal nav-link" href="#id3">
     What Is Attention
    </a>
   </li>
   <li class="toc-h3 nav-item toc-entry">
    <a class="reference internal nav-link" href="#self-attention">
     Self-Attention
    </a>
   </li>
   <li class="toc-h3 nav-item toc-entry">
    <a class="reference internal nav-link" href="#multi-head-attention">
     Multi-Head Attention
    </a>
   </li>
  </ul>
 </li>
 <li class="toc-h2 nav-item toc-entry">
  <a class="reference internal nav-link" href="#encoder">
   Encoder
  </a>
 </li>
 <li class="toc-h2 nav-item toc-entry">
  <a class="reference internal nav-link" href="#decoder">
   Decoder
  </a>
 </li>
 <li class="toc-h2 nav-item toc-entry">
  <a class="reference internal nav-link" href="#id4">
   Residual Connections
  </a>
 </li>
 <li class="toc-h2 nav-item toc-entry">
  <a class="reference internal nav-link" href="#mask">
   Mask
  </a>
 </li>
 <li class="toc-h2 nav-item toc-entry">
  <a class="reference internal nav-link" href="#id5">
   Positional Encoding
  </a>
 </li>
 <li class="toc-h2 nav-item toc-entry">
  <a class="reference internal nav-link" href="#id6">
   Full Model
  </a>
 </li>
 <li class="toc-h2 nav-item toc-entry">
  <a class="reference internal nav-link" href="#id7">
   Training
  </a>
  <ul class="nav section-nav flex-column">
   <li class="toc-h3 nav-item toc-entry">
    <a class="reference internal nav-link" href="#traning-loop">
     Training Loop
    </a>
   </li>
   <li class="toc-h3 nav-item toc-entry">
    <a class="reference internal nav-link" href="#id8">
     Optimizer
    </a>
   </li>
   <li class="toc-h3 nav-item toc-entry">
    <a class="reference internal nav-link" href="#id9">
     Regularization
    </a>
   </li>
  </ul>
 </li>
 <li class="toc-h2 nav-item toc-entry">
  <a class="reference internal nav-link" href="#id10">
   Summary
  </a>
  <ul class="nav section-nav flex-column">
   <li class="toc-h3 nav-item toc-entry">
    <a class="reference internal nav-link" href="#id11">
     Advantages
    </a>
   </li>
   <li class="toc-h3 nav-item toc-entry">
    <a class="reference internal nav-link" href="#id12">
     Disadvantages
    </a>
   </li>
  </ul>
 </li>
 <li class="toc-h2 nav-item toc-entry">
  <a class="reference internal nav-link" href="#id13">
   References
  </a>
 </li>
</ul>

    </nav>
</div>
    </div>
    <div class="article row">
        <div class="col pl-md-3 pl-lg-5 content-container">
            <!-- Table of contents that is only displayed when printing the page -->
            <div id="jb-print-docs-body" class="onlyprint">
                <h1>Transformer Explained</h1>
                <!-- Table of contents -->
                <div id="print-main-content">
                    <div id="jb-print-toc">
                        
                        <div>
                            <h2> Contents </h2>
                        </div>
                        <nav aria-label="Page">
                            <ul class="visible nav section-nav flex-column">
 <li class="toc-h2 nav-item toc-entry">
  <a class="reference internal nav-link" href="#id1">
   Preface
  </a>
 </li>
 <li class="toc-h2 nav-item toc-entry">
  <a class="reference internal nav-link" href="#id2">
   Overall Architecture
  </a>
 </li>
 <li class="toc-h2 nav-item toc-entry">
  <a class="reference internal nav-link" href="#attention">
   Attention
  </a>
  <ul class="nav section-nav flex-column">
   <li class="toc-h3 nav-item toc-entry">
    <a class="reference internal nav-link" href="#id3">
     What Is Attention
    </a>
   </li>
   <li class="toc-h3 nav-item toc-entry">
    <a class="reference internal nav-link" href="#self-attention">
     Self-Attention
    </a>
   </li>
   <li class="toc-h3 nav-item toc-entry">
    <a class="reference internal nav-link" href="#multi-head-attention">
     Multi-Head Attention
    </a>
   </li>
  </ul>
 </li>
 <li class="toc-h2 nav-item toc-entry">
  <a class="reference internal nav-link" href="#encoder">
   Encoder
  </a>
 </li>
 <li class="toc-h2 nav-item toc-entry">
  <a class="reference internal nav-link" href="#decoder">
   Decoder
  </a>
 </li>
 <li class="toc-h2 nav-item toc-entry">
  <a class="reference internal nav-link" href="#id4">
   Residual Connections
  </a>
 </li>
 <li class="toc-h2 nav-item toc-entry">
  <a class="reference internal nav-link" href="#mask">
   Mask
  </a>
 </li>
 <li class="toc-h2 nav-item toc-entry">
  <a class="reference internal nav-link" href="#id5">
   Positional Encoding
  </a>
 </li>
 <li class="toc-h2 nav-item toc-entry">
  <a class="reference internal nav-link" href="#id6">
   Full Model
  </a>
 </li>
 <li class="toc-h2 nav-item toc-entry">
  <a class="reference internal nav-link" href="#id7">
   Training
  </a>
  <ul class="nav section-nav flex-column">
   <li class="toc-h3 nav-item toc-entry">
    <a class="reference internal nav-link" href="#traning-loop">
     Training Loop
    </a>
   </li>
   <li class="toc-h3 nav-item toc-entry">
    <a class="reference internal nav-link" href="#id8">
     Optimizer
    </a>
   </li>
   <li class="toc-h3 nav-item toc-entry">
    <a class="reference internal nav-link" href="#id9">
     Regularization
    </a>
   </li>
  </ul>
 </li>
 <li class="toc-h2 nav-item toc-entry">
  <a class="reference internal nav-link" href="#id10">
   Summary
  </a>
  <ul class="nav section-nav flex-column">
   <li class="toc-h3 nav-item toc-entry">
    <a class="reference internal nav-link" href="#id11">
     Advantages
    </a>
   </li>
   <li class="toc-h3 nav-item toc-entry">
    <a class="reference internal nav-link" href="#id12">
     Disadvantages
    </a>
   </li>
  </ul>
 </li>
 <li class="toc-h2 nav-item toc-entry">
  <a class="reference internal nav-link" href="#id13">
   References
  </a>
 </li>
</ul>

                        </nav>
                    </div>
                </div>
            </div>
            <main id="main-content" role="main">
                
              <div>
                
  <section class="tex2jax_ignore mathjax_ignore" id="transformer">
<h1>Transformer Explained<a class="headerlink" href="#transformer" title="Permalink to this heading">#</a></h1>
<p>
<font size=3><b>[Transformer] Attention Is All You Need</b></font>
<br>
<font size=2>Ashish Vaswani, Noam Shazeer, Niki Parmar, Jakob Uszkoreit, Llion Jones, Aidan N. Gomez, Łukasz Kaiser, Illia Polosukhin.</font>
<br>
<font size=2>NIPS 2017.</font>
<a href='https://arxiv.org/pdf/1706.03762.pdf'>[paper]</a> <a href='https://github.com/tensorflow/tensor2tensor'>[tensorflow]</a> <a href='https://github.com/harvardnlp/annotated-transformer'>[pytorch]</a> 
<br>
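<font size=3>Annotated by: 邹雨衡 (Zou Yuheng), undergraduate at the University of International Business and Economics; 牛志康 (Niu Zhikang), undergraduate at Xidian University</font>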
<br>
</p>
<section id="id1">
<h2>Preface<a class="headerlink" href="#id1" title="Permalink to this heading">#</a></h2>
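<p>The attention mechanism and its variants have become a basic building block of natural language processing (NLP) models, and have to a considerable extent displaced the RNN (recurrent neural network) from its mainstream position. BERT, which in 2018 substantially improved results on many NLP tasks, is built on attention. Yet from its introduction in 2014 until 2017, although attention was widely applied across deep learning, it mostly served as a component of CNNs (convolutional neural networks) and RNNs, and NLP tasks were still dominated by RNNs and variants such as the LSTM (long short-term memory network).</p>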
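<p>Although RNNs and LSTMs are naturally suited to sequence modeling in NLP, they have two drawbacks that are hard to ignore:</p>
<p>1. An RNN computes in one direction, step by step. The sequence must be fed in order and processed serially, which prevents parallel computation and makes training slow.</p>
<p>2. An RNN struggles to capture long-term dependencies: for very long sequences, it has difficulty relating inputs that are far apart. The LSTM mitigates this to some degree with its gating mechanism, but the ability of RNNs to capture long-term dependencies remains unsatisfactory.</p>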
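<p>The first limitation can be illustrated with a minimal sketch. It assumes a toy input of 5 time steps and 8 features (hypothetical sizes), and it uses plain scaled dot-product attention without learned projections as a simplified stand-in for the attention mechanism covered later in this article. The point is only the shape of the computation: the RNN needs a loop over time steps, while attention handles all positions in one batched matrix product.</p>

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
seq_len, d_model = 5, 8                  # hypothetical toy sizes
x = torch.randn(1, seq_len, d_model)     # (batch, time, features)

# RNN: step t needs the hidden state produced at step t-1,
# so the time dimension has to be processed serially.
cell = nn.RNNCell(d_model, d_model)
h = torch.zeros(1, d_model)
rnn_states = []
for t in range(seq_len):
    h = cell(x[:, t], h)
    rnn_states.append(h)
rnn_out = torch.stack(rnn_states, dim=1)  # (1, seq_len, d_model)

# Attention (simplified: x plays the role of query, key and value):
# all pairwise position scores come from a single matrix product.
scores = x @ x.transpose(-2, -1) / d_model ** 0.5   # (1, seq_len, seq_len)
weights = scores.softmax(dim=-1)                     # each row sums to 1
attn_out = weights @ x                               # (1, seq_len, d_model)

print(rnn_out.shape, attn_out.shape)
```

<p>Both outputs have the same shape, but only the attention path avoids a step-by-step dependency. Every position also reaches every other position in a single step, which relates to the second limitation.</p>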
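<p>To address these two problems, Vaswani et al. published the paper "Attention Is All You Need" in 2017. It abandoned the conventional CNN and RNN architectures and proposed the Transformer, a new model built entirely on the attention mechanism. The Transformer addressed both problems, achieved state-of-the-art results on several tasks at a lower training cost, and offered a new direction for NLP. Since then, attention has been part of the mainstream architecture for NLP, and many high-performing pretrained models, such as BERT and OpenAI GPT, are based on the Transformer.</p>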
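<p>This article explains the model in terms of both its principles and its code, with an emphasis on the implementation. Note that the original Transformer source code is written in TensorFlow. Here we instead walk through the PyTorch-based <a class="reference external" href="https://github.com/harvardnlp/annotated-transformer">Annotated Transformer</a> code from the harvardnlp group at Harvard University, to help readers understand the implementation details.</p>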
</section>
<section id="id2">
<h2>Overall Architecture<a class="headerlink" href="#id2" title="Permalink to this heading">#</a></h2>
<p>The Transformer was developed for Seq2Seq (sequence-to-sequence) tasks in NLP and retains the Encoder-Decoder structure of Seq2Seq models. The overall architecture is shown below:</p>
<div align=center><img src="./figures/transformer_architecture.png" alt="image-20230127193646262" style="zoom:50%;"/></div>
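<p>The Transformer consists of an Encoder, a Decoder, a Softmax classifier, and two embedding layers. In the figure above, the box on the left is the Encoder and the box on the right is the Decoder.</p>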
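<p>Because it targets Seq2Seq tasks, the Transformer is trained on a corpus of sentence pairs. The concrete task may be machine translation, reading comprehension, dialogue, and so on; the original paper trained machine translation models, including English-to-German. During training, each sentence pair is split into an input sequence and an output sequence. The input sequence enters the Encoder on the left through an embedding layer, and the output sequence enters the Decoder on the right through an embedding layer. The Encoder encodes the input sequence and passes the result to the Decoder. The Decoder then computes over the history of the output sequence together with the Encoder output, and its result passes through a linear layer and a Softmax classifier to produce the predicted probabilities. The overall logic is shown below:</p>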
<div align=center><img src="./figures/transformer_datalink.png" alt="image-20230127193646262" style="zoom:50%;"/></div>
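<p>The data flow described above can be sketched in a few lines. This is only an illustration under stated assumptions, not the Annotated Transformer code: it uses PyTorch's built-in <code>torch.nn.Transformer</code> as the Encoder-Decoder core, toy vocabulary and model sizes, random token ids, and untrained weights, and it leaves out positional encoding. The names <code>src_embed</code>, <code>tgt_embed</code>, <code>core</code>, and <code>generator</code> are chosen for this sketch.</p>

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
vocab_size, d_model = 11, 16                   # hypothetical toy sizes
src = torch.randint(1, vocab_size, (2, 7))     # input token ids  (batch, src_len)
tgt = torch.randint(1, vocab_size, (2, 6))     # output token ids (batch, tgt_len)

src_embed = nn.Embedding(vocab_size, d_model)  # embedding layer for the input side
tgt_embed = nn.Embedding(vocab_size, d_model)  # embedding layer for the output side
core = nn.Transformer(d_model=d_model, nhead=4,
                      num_encoder_layers=2, num_decoder_layers=2,
                      dim_feedforward=32, dropout=0.0, batch_first=True)
generator = nn.Linear(d_model, vocab_size)     # final linear layer before softmax

# The Decoder reads the output sequence shifted right (its "history")
# and is trained to predict the next token at every position.
tgt_in, tgt_out = tgt[:, :-1], tgt[:, 1:]
T = tgt_in.size(1)
# Additive mask: -inf above the diagonal hides future positions.
causal_mask = torch.triu(torch.full((T, T), float("-inf")), diagonal=1)

hidden = core(src_embed(src), tgt_embed(tgt_in), tgt_mask=causal_mask)
probs = generator(hidden).softmax(dim=-1)      # (batch, T, vocab_size)

print(probs.shape)
```

<p>In training, <code>probs</code> (or the logits before the softmax) would be compared against <code>tgt_out</code> with a classification loss.</p>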
<p>The model as a whole is implemented as an Encoder-Decoder architecture:</p>
<div class="highlight-python notranslate"><div class="highlight"><pre><span></span><span class="kn">import</span> <span class="nn">torch.nn</span> <span class="k">as</span> <span class="nn">nn</span>


<span class="k">class</span> <span class="nc">EncoderDecoder</span><span class="p">(</span><span class="n">nn</span><span class="o">.</span><span class="n">Module</span><span class="p">):</span>
    <span class="sd">&quot;&quot;&quot;</span>
<span class="sd">    A standard Encoder-Decoder architecture. Base for this and many</span>
<span class="sd">    other models.</span>
<span class="sd">    &quot;&quot;&quot;</span>

    <span class="k">def</span> <span class="fm">__init__</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">encoder</span><span class="p">,</span> <span class="n">decoder</span><span class="p">,</span> <span class="n">src_embed</span><span class="p">,</span> <span class="n">tgt_embed</span><span class="p">,</span> <span class="n">generator</span><span class="p">):</span>
        <span class="nb">super</span><span class="p">(</span><span class="n">EncoderDecoder</span><span class="p">,</span> <span class="bp">self</span><span class="p">)</span><span class="o">.</span><span class="fm">__init__</span><span class="p">()</span>
        <span class="bp">self</span><span class="o">.</span><span class="n">encoder</span> <span class="o">=</span> <span class="n">encoder</span>
        <span class="c1"># Encoder</span>
        <span class="bp">self</span><span class="o">.</span><span class="n">decoder</span> <span class="o">=</span> <span class="n">decoder</span>
        <span class="c1"># Decoder</span>
        <span class="bp">self</span><span class="o">.</span><span class="n">src_embed</span> <span class="o">=</span> <span class="n">src_embed</span>
        <span class="c1"># Embedding function for the source (input) sequence</span>
        <span class="bp">self</span><span class="o">.</span><span class="n">tgt_embed</span> <span class="o">=</span> <span class="n">tgt_embed</span>
        <span class="c1"># Embedding function for the target (output) sequence</span>
        <span class="bp">self</span><span class="o">.</span><span class="n">generator</span> <span class="o">=</span> <span class="n">generator</span>
        <span class="c1"># Linear layer + softmax that outputs class probabilities (the top of the architecture diagram)</span>

    <span class="k">def</span> <span class="nf">forward</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">src</span><span class="p">,</span> <span class="n">tgt</span><span class="p">,</span> <span class="n">src_mask</span><span class="p">,</span> <span class="n">tgt_mask</span><span class="p">):</span>
        <span class="c1"># Forward pass</span>
        <span class="s2">&quot;Take in and process masked src and target sequences.&quot;</span>
        <span class="k">return</span> <span class="bp">self</span><span class="o">.</span><span class="n">decode</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">encode</span><span class="p">(</span><span class="n">src</span><span class="p">,</span> <span class="n">src_mask</span><span class="p">),</span> <span class="n">src_mask</span><span class="p">,</span> <span class="n">tgt</span><span class="p">,</span> <span class="n">tgt_mask</span><span class="p">)</span>

    <span class="k">def</span> <span class="nf">encode</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">src</span><span class="p">,</span> <span class="n">src_mask</span><span class="p">):</span>
        <span class="c1"># src: the input sequence</span>
        <span class="c1"># src_mask: the mask for the input sequence</span>
        <span class="k">return</span> <span class="bp">self</span><span class="o">.</span><span class="n">encoder</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">src_embed</span><span class="p">(</span><span class="n">src</span><span class="p">),</span> <span class="n">src_mask</span><span class="p">)</span>
        <span class="c1"># Embed the input sequence, then pass it to the encoder together with the mask</span>

    <span class="k">def</span> <span class="nf">decode</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">memory</span><span class="p">,</span> <span class="n">src_mask</span><span class="p">,</span> <span class="n">tgt</span><span class="p">,</span> <span class="n">tgt_mask</span><span class="p">):</span>
        <span class="k">return</span> <span class="bp">self</span><span class="o">.</span><span class="n">decoder</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">tgt_embed</span><span class="p">(</span><span class="n">tgt</span><span class="p">),</span> <span class="n">memory</span><span class="p">,</span> <span class="n">src_mask</span><span class="p">,</span> <span class="n">tgt_mask</span><span class="p">)</span>
        <span class="c1"># Decoding mirrors encoding; memory is simply the encoder output</span>
</pre></div>
</div>
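<p>The forward function in this code first encodes with the Encoder and then passes the result to the Decoder for decoding. The final linear classification (generator) is stored on the class but is not applied inside the forward function.</p>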
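<p>To make the data flow concrete, here is a minimal sketch that wires the class up with stand-in components. ToyEncoder, ToyDecoder, the vocabulary size, and the use of nn.Sequential with LogSoftmax as the generator are all assumptions made for illustration; they are not the components used by the Annotated Transformer.</p>

```python
import torch
import torch.nn as nn


class EncoderDecoder(nn.Module):
    """Same structure as the class above, repeated so the sketch runs on its own."""

    def __init__(self, encoder, decoder, src_embed, tgt_embed, generator):
        super().__init__()
        self.encoder = encoder
        self.decoder = decoder
        self.src_embed = src_embed
        self.tgt_embed = tgt_embed
        self.generator = generator

    def forward(self, src, tgt, src_mask, tgt_mask):
        return self.decode(self.encode(src, src_mask), src_mask, tgt, tgt_mask)

    def encode(self, src, src_mask):
        return self.encoder(self.src_embed(src), src_mask)

    def decode(self, memory, src_mask, tgt, tgt_mask):
        return self.decoder(self.tgt_embed(tgt), memory, src_mask, tgt_mask)


class ToyEncoder(nn.Module):
    """Hypothetical stand-in: returns the embeddings unchanged."""

    def forward(self, x, mask):
        return x


class ToyDecoder(nn.Module):
    """Hypothetical stand-in: adds the mean of the encoder output to each target position."""

    def forward(self, x, memory, src_mask, tgt_mask):
        return x + memory.mean(dim=1, keepdim=True)


torch.manual_seed(0)
vocab, d_model = 11, 8
model = EncoderDecoder(
    ToyEncoder(),
    ToyDecoder(),
    nn.Embedding(vocab, d_model),
    nn.Embedding(vocab, d_model),
    nn.Sequential(nn.Linear(d_model, vocab), nn.LogSoftmax(dim=-1)),
)

src = torch.randint(0, vocab, (2, 5))   # batch of 2, source length 5
tgt = torch.randint(0, vocab, (2, 4))   # batch of 2, target length 4

hidden = model(src, tgt, None, None)    # forward stops at the decoder output
log_probs = model.generator(hidden)     # the generator is applied as a separate step
```

<p>The decoder output has one d_model-sized vector per target position, and only the separate generator call turns it into a distribution over the vocabulary.</p>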
<p>The following sections go through the implementation details and principles of the Transformer one by one.</p>
</section>
<section id="attention">
<h2>Attention<a class="headerlink" href="#attention" title="永久链接至标题">#</a></h2>
<p>The attention mechanism is one of the cores of the Transformer. A detailed treatment of its ideas and operations would take considerable space, so here we only outline the idea and the rough computation. For more detail, see references such as <a class="reference external" href="https://towardsdatascience.com/attaining-attention-in-deep-learning-a712f93bdb1e">Understanding Attention In Deep Learning (NLP)</a> and <a class="reference external" href="https://lilianweng.github.io/posts/2018-06-24-attention/">Attention? Attention!</a>. Below, we introduce the attention used in the Transformer in three steps: what attention is, self-attention, and Multi-Head attention.</p>
<section id="id3">
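<h3>What Is Attention<a class="headerlink" href="#id3" title="永久链接至标题">#</a></h3>
<p>The attention mechanism originated in computer vision. Its core idea is that when we look at an image, we usually do not need to see everything clearly and can instead focus on the important parts. Likewise, in NLP we can often focus attention on one or a few tokens to compute more efficiently and with better results.</p>
<p>The attention mechanism computes the relevance between a <strong>Query</strong> and a <strong>Key</strong> and uses it as the weight in a weighted sum of the Value, thereby modeling how each word in a sequence relates to the other words. The rough computation is:</p>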
<div align=center><img src="./figures/transformer_attention.png" alt="image-20230129185638102" style="zoom:50%;"/></div>
<ol class="simple">
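<li><p>From the input and the parameter matrices, obtain the query <span class="math notranslate nohighlight">\(q\)</span>, the key <span class="math notranslate nohighlight">\(k\)</span>, and the value <span class="math notranslate nohighlight">\(v\)</span>. Intuitively, <span class="math notranslate nohighlight">\(q\)</span> is the other sentence (or phrase) for which attention is computed, <span class="math notranslate nohighlight">\(v\)</span> is the sentence being attended to, and <span class="math notranslate nohighlight">\(k\)</span> holds the key for each word of that sentence (that is, for each word of <span class="math notranslate nohighlight">\(v\)</span>). <span class="math notranslate nohighlight">\(k\)</span> and <span class="math notranslate nohighlight">\(v\)</span> must contain the same number of positions, because each key is paired with one value. <span class="math notranslate nohighlight">\(q\)</span> may contain a different number of positions, as long as <span class="math notranslate nohighlight">\(q\)</span> and <span class="math notranslate nohighlight">\(k\)</span> share the same feature dimension so that <span class="math notranslate nohighlight">\(q\)</span> can be multiplied by <span class="math notranslate nohighlight">\(k^T\)</span>.</p></li>
<li><p>For each element <span class="math notranslate nohighlight">\(q_i\)</span> of <span class="math notranslate nohighlight">\(q\)</span>, take the dot product of <span class="math notranslate nohighlight">\(q_i\)</span> with <span class="math notranslate nohighlight">\(k\)</span> and apply softmax. The resulting vector gives the attention of <span class="math notranslate nohighlight">\(q_i\)</span> over every position of the sentence.</p></li>
<li><p>Use the vector from the previous step as weights for a weighted sum over <span class="math notranslate nohighlight">\(v\)</span>, then stack the weighted sums obtained for all elements of <span class="math notranslate nohighlight">\(q\)</span> to form the final output.</p></li>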
</ol>
<p>Here q, k, and v are each obtained by multiplying the input by one of three parameter matrices:</p>
<div align=center><img src="./figures/transformer_attention_compute.png" alt="image-20230129185638102" style="zoom:50%;"/></div>
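<p>In practice, to speed things up through parallel computation, the individual <span class="math notranslate nohighlight">\(q\)</span>, <span class="math notranslate nohighlight">\(k\)</span>, and <span class="math notranslate nohighlight">\(v\)</span> vectors are stacked into matrices so that everything is computed in a single pass.</p>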
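<p>The three steps listed earlier can be traced by hand for a single query vector. The sketch below uses only the Python standard library; the function names and the toy keys and values are made up for illustration, and the scaling factor introduced in the next formula is deliberately left out.</p>

```python
import math


def softmax(scores):
    """Numerically stable softmax over a list of floats."""
    m = max(scores)
    exps = [math.exp(s - m) for s in scores]
    total = sum(exps)
    return [e / total for e in exps]


def attend(q_i, keys, values):
    """Attention output for one query vector q_i.

    keys and values are lists of vectors with one entry per position.
    """
    # Step 2: dot product of q_i with every key, then softmax.
    scores = [sum(a * b for a, b in zip(q_i, k)) for k in keys]
    weights = softmax(scores)
    # Step 3: weighted sum of the value vectors.
    dim = len(values[0])
    output = [sum(w * v[d] for w, v in zip(weights, values)) for d in range(dim)]
    return weights, output


# Toy data: 3 positions, keys and values of size 2.
keys = [[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]]
values = [[10.0, 0.0], [0.0, 10.0], [5.0, 5.0]]
weights, output = attend([2.0, 0.0], keys, values)
```

<p>The query points along the first axis, so the first and third keys receive equal, larger weights than the second, and the output is pulled toward the corresponding values.</p>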
<p>In the Transformer specifically, the formula is:</p>
<div class="math notranslate nohighlight">
\[\mathrm{Attention}(Q,K,V) = \mathrm{softmax}\left(\frac{QK^T}{\sqrt{d_k}}\right)V\]</div>
<p>where <span class="math notranslate nohighlight">\(d_k\)</span> is the dimension of the key vector <span class="math notranslate nohighlight">\(k\)</span>. Dividing by <span class="math notranslate nohighlight">\(\sqrt{d_k}\)</span> mainly serves to keep the gradients stable during training. A worked example is shown below:</p>
<div align=center><img src="./figures/transformer_attention_compute_2.png" alt="image-20230129185638102" style="zoom:50%;"/></div>
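<p>The effect of the scaling factor can be seen on a small set of numbers. The original paper explains that for large d_k the dot products grow large in magnitude and push the softmax into regions with extremely small gradients. The scores below are hypothetical values chosen only to illustrate that saturation:</p>

```python
import math


def softmax(scores):
    """Numerically stable softmax over a list of floats."""
    m = max(scores)
    exps = [math.exp(s - m) for s in scores]
    total = sum(exps)
    return [e / total for e in exps]


d_k = 64
raw_scores = [16.0, 8.0, 0.0]  # hypothetical unscaled dot products
scaled_scores = [s / math.sqrt(d_k) for s in raw_scores]  # divide by sqrt(64) = 8

unscaled_weights = softmax(raw_scores)    # almost all weight on the first position
scaled_weights = softmax(scaled_scores)   # weight is spread over all three positions
```

<p>Without scaling, the softmax output is close to one-hot, which is the regime where gradients vanish. After scaling, the other positions keep a noticeable share of the weight.</p>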
<p>A basic implementation of the attention mechanism is as follows:</p>
<div class="highlight-python notranslate"><div class="highlight"><pre><span></span><span class="k">def</span> <span class="nf">attention</span><span class="p">(</span><span class="n">query</span><span class="p">,</span> <span class="n">key</span><span class="p">,</span> <span class="n">value</span><span class="p">,</span> <span class="n">mask</span><span class="o">=</span><span class="kc">None</span><span class="p">,</span> <span class="n">dropout</span><span class="o">=</span><span class="kc">None</span><span class="p">):</span>
    <span class="s2">&quot;Compute &#39;Scaled Dot Product Attention&#39;&quot;</span>
    <span class="n">d_k</span> <span class="o">=</span> <span class="n">query</span><span class="o">.</span><span class="n">size</span><span class="p">(</span><span class="o">-</span><span class="mi">1</span><span class="p">)</span> <span class="c1"># Last dimension of query, which equals the key dimension d_k</span>
    <span class="n">scores</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">matmul</span><span class="p">(</span><span class="n">query</span><span class="p">,</span> <span class="n">key</span><span class="o">.</span><span class="n">transpose</span><span class="p">(</span><span class="o">-</span><span class="mi">2</span><span class="p">,</span> <span class="o">-</span><span class="mi">1</span><span class="p">))</span> <span class="o">/</span> <span class="n">math</span><span class="o">.</span><span class="n">sqrt</span><span class="p">(</span><span class="n">d_k</span><span class="p">)</span>
    <span class="c1"># Dot product of Q and K, divided by sqrt(d_k)</span>
    <span class="c1"># transpose swaps the last two dimensions of key so that the matmul computes Q times K^T</span>
    <span class="k">if</span> <span class="n">mask</span> <span class="ow">is</span> <span class="ow">not</span> <span class="kc">None</span><span class="p">:</span>
        <span class="n">scores</span> <span class="o">=</span> <span class="n">scores</span><span class="o">.</span><span class="n">masked_fill</span><span class="p">(</span><span class="n">mask</span> <span class="o">==</span> <span class="mi">0</span><span class="p">,</span> <span class="o">-</span><span class="mf">1e9</span><span class="p">)</span>
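        <span class="c1"># masked_fill masks the scores with a boolean matrix: positions where mask == 0 are set to -1e9</span>
        <span class="c1"># This works together with the subsequent_mask function, whose output can be passed in as mask</span>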
    <span class="n">p_attn</span> <span class="o">=</span> <span class="n">scores</span><span class="o">.</span><span class="n">softmax</span><span class="p">(</span><span class="n">dim</span><span class="o">=-</span><span class="mi">1</span><span class="p">)</span>
    <span class="k">if</span> <span class="n">dropout</span> <span class="ow">is</span> <span class="ow">not</span> <span class="kc">None</span><span class="p">:</span>
        <span class="n">p_attn</span> <span class="o">=</span> <span class="n">dropout</span><span class="p">(</span><span class="n">p_attn</span><span class="p">)</span>
        <span class="c1"># Apply dropout to the attention weights</span>
    <span class="k">return</span> <span class="n">torch</span><span class="o">.</span><span class="n">matmul</span><span class="p">(</span><span class="n">p_attn</span><span class="p">,</span> <span class="n">value</span><span class="p">),</span> <span class="n">p_attn</span>
    <span class="c1"># Weighted sum over value using the attention weights</span>
</pre></div>
</div>
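<p>As a usage sketch, the function can be called with a lower-triangular mask so that each position only attends to itself and earlier positions. The function is repeated so the snippet runs on its own. The tensor sizes and the use of torch.tril to build the mask are choices made for this illustration.</p>

```python
import math
import torch


def attention(query, key, value, mask=None, dropout=None):
    "Compute 'Scaled Dot Product Attention'"
    d_k = query.size(-1)
    scores = torch.matmul(query, key.transpose(-2, -1)) / math.sqrt(d_k)
    if mask is not None:
        scores = scores.masked_fill(mask == 0, -1e9)
    p_attn = scores.softmax(dim=-1)
    if dropout is not None:
        p_attn = dropout(p_attn)
    return torch.matmul(p_attn, value), p_attn


torch.manual_seed(0)
batch, seq_len, d_k = 1, 3, 4
x = torch.randn(batch, seq_len, d_k)

# Lower-triangular mask: position i may only attend to positions 0..i.
causal_mask = torch.tril(torch.ones(1, seq_len, seq_len))

out, p_attn = attention(x, x, x, mask=causal_mask)
```

<p>Each row of p_attn sums to 1, and the entries above the diagonal are effectively zero because their scores were replaced by -1e9 before the softmax.</p>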
</section>
<section id="self-attention">
<h3>Self-Attention<a class="headerlink" href="#self-attention" title="永久链接至标题">#</a></h3>
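<p>From the description above, the essence of attention is to compute the similarity between the elements of two sequences, find how relevant each element of one sequence is to each element of the other, and then weight by that relevance, that is, allocate attention. These two sequences are the sources of <span class="math notranslate nohighlight">\(Q\)</span>, <span class="math notranslate nohighlight">\(K\)</span>, and <span class="math notranslate nohighlight">\(V\)</span> in the computation.</p>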
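<p>In classic attention, <span class="math notranslate nohighlight">\(Q\)</span> usually comes from one sequence while <span class="math notranslate nohighlight">\(K\)</span> and <span class="math notranslate nohighlight">\(V\)</span> come from another, each obtained through a parameter matrix, so that the relationship between the two sequences can be modeled. For example, in the encoder-decoder attention of the Transformer Decoder, <span class="math notranslate nohighlight">\(Q\)</span> comes from the Decoder side (the output of the preceding Decoder sublayer), while <span class="math notranslate nohighlight">\(K\)</span> and <span class="math notranslate nohighlight">\(V\)</span> come from the Encoder output. This models the relationship between the encoded source and the decoding history, so that both can be combined to predict what comes next.</p>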
<p>In the Transformer Encoder, however, a variant of attention is used: self-attention. Self-attention computes the attention distribution of each element of a sequence over the other elements of the same sequence. In other words, <span class="math notranslate nohighlight">\(Q\)</span>, <span class="math notranslate nohighlight">\(K\)</span>, and <span class="math notranslate nohighlight">\(V\)</span> are all computed from the same input through different parameter matrices. In the Encoder, <span class="math notranslate nohighlight">\(Q\)</span>, <span class="math notranslate nohighlight">\(K\)</span>, and <span class="math notranslate nohighlight">\(V\)</span> are obtained by multiplying the input by the parameter matrices <span class="math notranslate nohighlight">\(W_q\)</span>, <span class="math notranslate nohighlight">\(W_k\)</span>, and <span class="math notranslate nohighlight">\(W_v\)</span> respectively, which models the relationship of each token in the input sentence to all other tokens.</p>
<p>For example, the self-attention layer in the Encoder can model the attention distribution of "it" over the other tokens in the sentence below:</p>
<div align=center><img src="./figures/transformer_selfattention.jpg" alt="image-20230129190819407" style="zoom:50%;"/></div>
<p>In code, self-attention is implemented simply by passing the same argument as the input for <span class="math notranslate nohighlight">\(Q\)</span>, <span class="math notranslate nohighlight">\(K\)</span>, and <span class="math notranslate nohighlight">\(V\)</span>:</p>
<div class="highlight-python notranslate"><div class="highlight"><pre><span></span><span class="n">x</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">sublayer</span><span class="p">[</span><span class="mi">0</span><span class="p">](</span><span class="n">x</span><span class="p">,</span> <span class="k">lambda</span> <span class="n">x</span><span class="p">:</span> <span class="bp">self</span><span class="o">.</span><span class="n">self_attn</span><span class="p">(</span><span class="n">x</span><span class="p">,</span> <span class="n">x</span><span class="p">,</span> <span class="n">x</span><span class="p">,</span> <span class="n">mask</span><span class="p">))</span>
<span class="c1"># Pass the self-attention layer in as the sublayer; the output is the self-attention result after the residual connection</span>
</pre></div>
</div>
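<p>The code above is part of the EncoderLayer implementation. self_attn is the attention layer, and all three arguments passed to it are <span class="math notranslate nohighlight">\(x\)</span>, which serve as the inputs from which <span class="math notranslate nohighlight">\(Q\)</span>, <span class="math notranslate nohighlight">\(K\)</span>, and <span class="math notranslate nohighlight">\(V\)</span> are computed. Because <span class="math notranslate nohighlight">\(Q\)</span>, <span class="math notranslate nohighlight">\(K\)</span>, and <span class="math notranslate nohighlight">\(V\)</span> all come from the same input, this realizes self-attention.</p>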
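<p>The difference between attention over one sequence and attention across two sequences comes down to which inputs feed the projections. The sketch below is a single-head illustration. The names attend, seq_a, and seq_b, and the bias-free nn.Linear projections standing in for the three parameter matrices, are assumptions for this example only.</p>

```python
import math
import torch
import torch.nn as nn

torch.manual_seed(0)
d_model = 8
# Three separate projections playing the role of W_q, W_k, W_v.
w_q, w_k, w_v = (nn.Linear(d_model, d_model, bias=False) for _ in range(3))


def attend(q_input, kv_input):
    """Project the inputs to Q, K, V and apply scaled dot-product attention."""
    q, k, v = w_q(q_input), w_k(kv_input), w_v(kv_input)
    scores = q @ k.transpose(-2, -1) / math.sqrt(d_model)
    return scores.softmax(dim=-1) @ v


seq_a = torch.randn(1, 5, d_model)  # a sequence of length 5
seq_b = torch.randn(1, 3, d_model)  # a different sequence of length 3

self_out = attend(seq_a, seq_a)    # self-attention: Q, K, V all come from seq_a
cross_out = attend(seq_b, seq_a)   # attention across sequences: Q from seq_b, K and V from seq_a
```

<p>The output always has one row per query position: five rows for self_out and three for cross_out, even though both attend over the five positions of seq_a.</p>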
</section>
<section id="multi-head-attention">
<h3>Multi-Head Attention<a class="headerlink" href="#multi-head-attention" title="永久链接至标题">#</a></h3>
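<p>Attention enables parallelization and the modeling of long-range dependencies, but a single attention computation can only model one kind of relationship, so a single attention mechanism struggles to capture all the relationships in a sentence. The Transformer therefore uses Multi-Head attention: attention is computed several times in parallel over the same input, each computation can model a different relationship, and the results are concatenated to form the final output, giving a fuller model of the linguistic information.</p>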
<div align=center><img src="./figures/transformer_Multi-Head attention.png" alt="image-20230129190819407" style="zoom:50%;"/></div>
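<p>In the original paper, the authors also show that in multi-head attention the different heads capture different information in a sentence, as in the figure below:</p>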
<div align=center><img src=".\figures\transformer_Multi-Head visual.jpg" alt="image-20230207203454545" style="zoom:50%;" />	</div>
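<p>The upper and lower rows show the self-attention results of two attention heads on the same sentence. Different heads capture relationships at different levels, so computing several heads at once models the sentence more completely.</p>
<p>The overall computation of Multi-Head attention is as follows:</p>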
<div align=center><img src="./figures/transformer_Multi-Head attention_compute.png" alt="image-20230129190819407" style="zoom:50%;"/></div>
<p>Multi-head attention simply applies several sets of attention to the original input sequence, concatenates the result of each set, and passes the concatenation through a linear layer to obtain the final output. As a formula:</p>
<div class="math notranslate nohighlight">
\[\begin{split}\mathrm{MultiHead}(Q, K, V) = \mathrm{Concat}(\mathrm{head_1}, ..., \mathrm{head_h})W^O \\
\text{where}~\mathrm{head_i} = \mathrm{Attention}(QW^Q_i, KW^K_i, VW^V_i)\end{split}\]</div>
<p>The code is somewhat more involved: it uses tensor operations to compute all heads in parallel. The overall flow is:</p>
<div class="highlight-python notranslate"><div class="highlight"><pre><span></span><span class="k">class</span> <span class="nc">MultiHeadedAttention</span><span class="p">(</span><span class="n">nn</span><span class="o">.</span><span class="n">Module</span><span class="p">):</span>
    <span class="c1"># Multi-head attention</span>
    <span class="k">def</span> <span class="fm">__init__</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">h</span><span class="p">,</span> <span class="n">d_model</span><span class="p">,</span> <span class="n">dropout</span><span class="o">=</span><span class="mf">0.1</span><span class="p">):</span>
        <span class="s2">&quot;Take in model size and number of heads.&quot;</span>
        <span class="nb">super</span><span class="p">(</span><span class="n">MultiHeadedAttention</span><span class="p">,</span> <span class="bp">self</span><span class="p">)</span><span class="o">.</span><span class="fm">__init__</span><span class="p">()</span>
        <span class="k">assert</span> <span class="n">d_model</span> <span class="o">%</span> <span class="n">h</span> <span class="o">==</span> <span class="mi">0</span>
        <span class="c1"># Assert that d_model is divisible by h; otherwise an AssertionError is raised</span>
        <span class="c1"># We assume d_v always equals d_k</span>
        <span class="bp">self</span><span class="o">.</span><span class="n">d_k</span> <span class="o">=</span> <span class="n">d_model</span> <span class="o">//</span> <span class="n">h</span>
        <span class="c1"># Dimension of each head's key (and value) vectors</span>
        <span class="bp">self</span><span class="o">.</span><span class="n">h</span> <span class="o">=</span> <span class="n">h</span>
        <span class="c1"># Number of heads</span>
        <span class="bp">self</span><span class="o">.</span><span class="n">linears</span> <span class="o">=</span> <span class="n">clones</span><span class="p">(</span><span class="n">nn</span><span class="o">.</span><span class="n">Linear</span><span class="p">(</span><span class="n">d_model</span><span class="p">,</span> <span class="n">d_model</span><span class="p">),</span> <span class="mi">4</span><span class="p">)</span>
        <span class="bp">self</span><span class="o">.</span><span class="n">attn</span> <span class="o">=</span> <span class="kc">None</span>
        <span class="bp">self</span><span class="o">.</span><span class="n">dropout</span> <span class="o">=</span> <span class="n">nn</span><span class="o">.</span><span class="n">Dropout</span><span class="p">(</span><span class="n">p</span><span class="o">=</span><span class="n">dropout</span><span class="p">)</span>

    <span class="k">def</span> <span class="nf">forward</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">query</span><span class="p">,</span> <span class="n">key</span><span class="p">,</span> <span class="n">value</span><span class="p">,</span> <span class="n">mask</span><span class="o">=</span><span class="kc">None</span><span class="p">):</span>
        <span class="s2">&quot;Implements Figure 2&quot;</span>
        <span class="k">if</span> <span class="n">mask</span> <span class="ow">is</span> <span class="ow">not</span> <span class="kc">None</span><span class="p">:</span>
            <span class="c1"># Same mask applied to all h heads.</span>
            <span class="n">mask</span> <span class="o">=</span> <span class="n">mask</span><span class="o">.</span><span class="n">unsqueeze</span><span class="p">(</span><span class="mi">1</span><span class="p">)</span>
        <span class="n">nbatches</span> <span class="o">=</span> <span class="n">query</span><span class="o">.</span><span class="n">size</span><span class="p">(</span><span class="mi">0</span><span class="p">)</span>

        <span class="c1"># 1) Do all the linear projections in batch from d_model =&gt; h x d_k</span>
        <span class="n">query</span><span class="p">,</span> <span class="n">key</span><span class="p">,</span> <span class="n">value</span> <span class="o">=</span> <span class="p">[</span>
            <span class="n">lin</span><span class="p">(</span><span class="n">x</span><span class="p">)</span><span class="o">.</span><span class="n">view</span><span class="p">(</span><span class="n">nbatches</span><span class="p">,</span> <span class="o">-</span><span class="mi">1</span><span class="p">,</span> <span class="bp">self</span><span class="o">.</span><span class="n">h</span><span class="p">,</span> <span class="bp">self</span><span class="o">.</span><span class="n">d_k</span><span class="p">)</span><span class="o">.</span><span class="n">transpose</span><span class="p">(</span><span class="mi">1</span><span class="p">,</span> <span class="mi">2</span><span class="p">)</span>
            <span class="k">for</span> <span class="n">lin</span><span class="p">,</span> <span class="n">x</span> <span class="ow">in</span> <span class="nb">zip</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">linears</span><span class="p">,</span> <span class="p">(</span><span class="n">query</span><span class="p">,</span> <span class="n">key</span><span class="p">,</span> <span class="n">value</span><span class="p">))</span>
        <span class="p">]</span>

        <span class="c1"># 2) Apply attention on all the projected vectors in batch.</span>
        <span class="n">x</span><span class="p">,</span> <span class="bp">self</span><span class="o">.</span><span class="n">attn</span> <span class="o">=</span> <span class="n">attention</span><span class="p">(</span>
            <span class="n">query</span><span class="p">,</span> <span class="n">key</span><span class="p">,</span> <span class="n">value</span><span class="p">,</span> <span class="n">mask</span><span class="o">=</span><span class="n">mask</span><span class="p">,</span> <span class="n">dropout</span><span class="o">=</span><span class="bp">self</span><span class="o">.</span><span class="n">dropout</span>
        <span class="p">)</span>
        <span class="c1"># x is the weighted-sum result; self.attn holds the attention weights</span>

        <span class="c1"># 3) &quot;Concat&quot; using a view and apply a final linear.</span>
        <span class="n">x</span> <span class="o">=</span> <span class="p">(</span>
            <span class="n">x</span><span class="o">.</span><span class="n">transpose</span><span class="p">(</span><span class="mi">1</span><span class="p">,</span> <span class="mi">2</span><span class="p">)</span>
            <span class="o">.</span><span class="n">contiguous</span><span class="p">()</span>
            <span class="o">.</span><span class="n">view</span><span class="p">(</span><span class="n">nbatches</span><span class="p">,</span> <span class="o">-</span><span class="mi">1</span><span class="p">,</span> <span class="bp">self</span><span class="o">.</span><span class="n">h</span> <span class="o">*</span> <span class="bp">self</span><span class="o">.</span><span class="n">d_k</span><span class="p">)</span>
        <span class="p">)</span>

        <span class="k">del</span> <span class="n">query</span>
        <span class="k">del</span> <span class="n">key</span>
        <span class="k">del</span> <span class="n">value</span>
        <span class="k">return</span> <span class="bp">self</span><span class="o">.</span><span class="n">linears</span><span class="p">[</span><span class="o">-</span><span class="mi">1</span><span class="p">](</span><span class="n">x</span><span class="p">)</span>
</pre></div>
</div>
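<p>The least obvious part of the class is how view and transpose split one d_model-sized projection into h heads and later merge them back. The following sketch isolates just that reshaping, using the base-model sizes from the paper (h = 8, d_k = 64, d_model = 512). The batch size and sequence length are arbitrary choices.</p>

```python
import torch

torch.manual_seed(0)
nbatches, seq_len, h, d_k = 2, 5, 8, 64
d_model = h * d_k  # 512
x = torch.randn(nbatches, seq_len, d_model)

# Split: (batch, seq, d_model) to (batch, seq, h, d_k), then to (batch, h, seq, d_k).
heads = x.view(nbatches, -1, h, d_k).transpose(1, 2)

# Merge: undo the transpose, make the memory contiguous, then flatten the heads.
merged = heads.transpose(1, 2).contiguous().view(nbatches, -1, h * d_k)
```

<p>Head i simply sees features i * d_k through (i + 1) * d_k - 1 of every position, and the merge step restores the original layout exactly. The call to contiguous is needed because view cannot reinterpret the non-contiguous tensor produced by transpose.</p>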
<p>PyTorch also provides a built-in API for Multi-Head Attention. A multi-head attention layer can be constructed and used directly as follows:</p>
<div class="highlight-python notranslate"><div class="highlight"><pre><span></span><span class="n">multihead_attn</span> <span class="o">=</span> <span class="n">nn</span><span class="o">.</span><span class="n">MultiheadAttention</span><span class="p">(</span><span class="n">embed_dim</span> <span class="p">,</span> <span class="n">num_heads</span><span class="p">)</span>
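<span class="c1"># Construct a multi-head attention layer</span>
<span class="c1"># embed_dim: total dimension of the model (the embedding size); num_heads: number of heads</span>
<span class="c1"># Optional parameters:</span>
<span class="c1"># dropout: dropout probability applied to the attention weights, default 0; bias: whether the projection layers use a bias</span>
<span class="c1"># add_bias_kv: whether to add a bias to the key and value sequences; add_zero_attn: whether to append a batch of zeros to the key and value sequences</span>
<span class="c1"># kdim: total number of features for K; vdim: total number of features for V; batch_first: whether inputs are laid out as (batch, seq, feature)</span>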
<span class="n">attn_output</span><span class="p">,</span> <span class="n">attn_output_weights</span> <span class="o">=</span> <span class="n">multihead_attn</span><span class="p">(</span><span class="n">query</span><span class="p">,</span> <span class="n">key</span><span class="p">,</span> <span class="n">value</span><span class="p">)</span>
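<span class="c1"># Forward pass</span>
<span class="c1"># Outputs:</span>
<span class="c1"># attn_output: result of shape (L, N, E) when batch_first=False, where L is the target sequence length, N the batch size, and E is embed_dim</span>
<span class="c1"># attn_output_weights: attention weights, returned only when need_weights=True (otherwise None)</span>
<span class="c1"># query, key, value are the three input tensors of the attention computation</span>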
</pre></div>
</div>
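<p>A runnable version of this call, with concrete shapes, might look like the sketch below. It assumes a PyTorch version recent enough to support batch_first (1.9 or later). The sizes are arbitrary, and the stated weight shape relies on the default behaviour of averaging the weights over heads.</p>

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
embed_dim, num_heads = 16, 4
mha = nn.MultiheadAttention(embed_dim, num_heads, batch_first=True)

query = torch.randn(2, 3, embed_dim)  # (N, L, E): batch 2, target length 3
key = torch.randn(2, 5, embed_dim)    # (N, S, E): source length 5
value = torch.randn(2, 5, embed_dim)  # (N, S, E)

attn_output, attn_weights = mha(query, key, value)
# attn_output: (N, L, E); attn_weights: (N, L, S), averaged over the heads by default
```

<p>With batch_first=True the batch dimension comes first in both the inputs and attn_output. With the default batch_first=False the layout is (L, N, E) as described in the comments above.</p>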
</section>
</section>
<section id="encoder">
<h2>Encoder<a class="headerlink" href="#encoder" title="永久链接至标题">#</a></h2>
<div align=center><img src="./figures/transformer_Encoder.png" alt="image-20230129182417098" style="zoom:50%;"/></div>
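<p>As shown in the figure, the Encoder consists of N EncoderLayers (N = 6 in the paper), and each EncoderLayer consists of two sublayers. In the code below, each layer is an EncoderLayer, and the code additionally applies a normalization layer at the end:</p>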
<div class="highlight-python notranslate"><div class="highlight"><pre><span></span><span class="k">class</span> <span class="nc">Encoder</span><span class="p">(</span><span class="n">nn</span><span class="o">.</span><span class="n">Module</span><span class="p">):</span>
    <span class="s2">&quot;Core encoder is a stack of N layers&quot;</span>
    <span class="c1"># The encoder is a stack of N EncoderLayers followed by a normalization layer</span>

    <span class="k">def</span> <span class="fm">__init__</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">layer</span><span class="p">,</span> <span class="n">N</span><span class="p">):</span>
        <span class="nb">super</span><span class="p">(</span><span class="n">Encoder</span><span class="p">,</span> <span class="bp">self</span><span class="p">)</span><span class="o">.</span><span class="fm">__init__</span><span class="p">()</span>
        <span class="bp">self</span><span class="o">.</span><span class="n">layers</span> <span class="o">=</span> <span class="n">clones</span><span class="p">(</span><span class="n">layer</span><span class="p">,</span> <span class="n">N</span><span class="p">)</span>
        <span class="c1"># clones is a helper defined by the authors that makes N copies of layer</span>
        <span class="bp">self</span><span class="o">.</span><span class="n">norm</span> <span class="o">=</span> <span class="n">LayerNorm</span><span class="p">(</span><span class="n">layer</span><span class="o">.</span><span class="n">size</span><span class="p">)</span>
        <span class="c1"># Normalization layer; its argument is the number of features</span>

    <span class="k">def</span> <span class="nf">forward</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">x</span><span class="p">,</span> <span class="n">mask</span><span class="p">):</span>
        <span class="s2">&quot;Pass the input (and mask) through each layer in turn.&quot;</span>
        <span class="k">for</span> <span class="n">layer</span> <span class="ow">in</span> <span class="bp">self</span><span class="o">.</span><span class="n">layers</span><span class="p">:</span>
            <span class="n">x</span> <span class="o">=</span> <span class="n">layer</span><span class="p">(</span><span class="n">x</span><span class="p">,</span> <span class="n">mask</span><span class="p">)</span>
        <span class="k">return</span> <span class="bp">self</span><span class="o">.</span><span class="n">norm</span><span class="p">(</span><span class="n">x</span><span class="p">)</span>
</pre></div>
</div>
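<p>The clones helper used above is not shown in this section. In the Annotated Transformer it is a small function along the following lines; treat this as a sketch rather than a verbatim copy. The nn.Linear layer used to try it out is an arbitrary example.</p>

```python
import copy
import torch.nn as nn


def clones(module, N):
    "Produce N identical layers."
    # deepcopy gives every copy its own parameters; ModuleList registers them with the parent module.
    return nn.ModuleList([copy.deepcopy(module) for _ in range(N)])


layers = clones(nn.Linear(4, 4), 3)
```

<p>The copies start from identical weights but do not share storage, so each of the N layers is trained independently.</p>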
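<p>1. The first part is a multi-head self-attention layer. The defining feature of the Transformer is that it abandons the traditional CNN and RNN architectures and is built entirely on attention. The EncoderLayer uses a Multi-Head self-attention layer to encode the relationships within the input. Attention is what gives the model its parallelism and its ability to model long-range dependencies. The characteristics and implementation of Multi-Head self-attention are described in the Attention section above. After the multi-head self-attention layer, the output passes through a normalization layer and then enters the next sublayer.</p>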
<div class="highlight-python notranslate"><div class="highlight"><pre><span></span><span class="k">class</span> <span class="nc">EncoderLayer</span><span class="p">(</span><span class="n">nn</span><span class="o">.</span><span class="n">Module</span><span class="p">):</span>
    <span class="s2">&quot;Encoder is made up of self-attn and feed forward (defined below)&quot;</span>

    <span class="k">def</span> <span class="fm">__init__</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">size</span><span class="p">,</span> <span class="n">self_attn</span><span class="p">,</span> <span class="n">feed_forward</span><span class="p">,</span> <span class="n">dropout</span><span class="p">):</span>
        <span class="nb">super</span><span class="p">(</span><span class="n">EncoderLayer</span><span class="p">,</span> <span class="bp">self</span><span class="p">)</span><span class="o">.</span><span class="fm">__init__</span><span class="p">()</span>
        <span class="bp">self</span><span class="o">.</span><span class="n">self_attn</span> <span class="o">=</span> <span class="n">self_attn</span>
        <span class="c1"># Self-attention layer</span>
        <span class="bp">self</span><span class="o">.</span><span class="n">feed_forward</span> <span class="o">=</span> <span class="n">feed_forward</span>
        <span class="c1"># Feed-forward network</span>
        <span class="bp">self</span><span class="o">.</span><span class="n">sublayer</span> <span class="o">=</span> <span class="n">clones</span><span class="p">(</span><span class="n">SublayerConnection</span><span class="p">(</span><span class="n">size</span><span class="p">,</span> <span class="n">dropout</span><span class="p">),</span> <span class="mi">2</span><span class="p">)</span>
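        <span class="c1"># Residual connection wrappers (SublayerConnection, defined later)</span>
        <span class="c1"># One wraps the self-attention layer, the other wraps the feed-forward network</span>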
        <span class="bp">self</span><span class="o">.</span><span class="n">size</span> <span class="o">=</span> <span class="n">size</span>

    <span class="k">def</span> <span class="nf">forward</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">x</span><span class="p">,</span> <span class="n">mask</span><span class="p">):</span>
        <span class="s2">&quot;Follow Figure 1 (left) for connections.&quot;</span>
        <span class="n">x</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">sublayer</span><span class="p">[</span><span class="mi">0</span><span class="p">](</span><span class="n">x</span><span class="p">,</span> <span class="k">lambda</span> <span class="n">x</span><span class="p">:</span> <span class="bp">self</span><span class="o">.</span><span class="n">self_attn</span><span class="p">(</span><span class="n">x</span><span class="p">,</span> <span class="n">x</span><span class="p">,</span> <span class="n">x</span><span class="p">,</span> <span class="n">mask</span><span class="p">))</span>
        <span class="c1"># Pass the self-attention layer in as the sublayer; the output is the self-attention result after the residual connection</span>
        <span class="k">return</span> <span class="bp">self</span><span class="o">.</span><span class="n">sublayer</span><span class="p">[</span><span class="mi">1</span><span class="p">](</span><span class="n">x</span><span class="p">,</span> <span class="bp">self</span><span class="o">.</span><span class="n">feed_forward</span><span class="p">)</span>
</pre></div>
</div>
<p>2. The second part is a fully connected network, called a “position-wise fully connected feed-forward network” in the paper. It is in effect a linear layer + activation function + dropout + linear layer.</p>
<div class="highlight-python notranslate"><div class="highlight"><pre><span></span><span class="k">class</span> <span class="nc">PositionwiseFeedForward</span><span class="p">(</span><span class="n">nn</span><span class="o">.</span><span class="n">Module</span><span class="p">):</span>
    <span class="c1"># The FeedForward block in the architecture diagram</span>

    <span class="k">def</span> <span class="fm">__init__</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">d_model</span><span class="p">,</span> <span class="n">d_ff</span><span class="p">,</span> <span class="n">dropout</span><span class="o">=</span><span class="mf">0.1</span><span class="p">):</span>
        <span class="nb">super</span><span class="p">(</span><span class="n">PositionwiseFeedForward</span><span class="p">,</span> <span class="bp">self</span><span class="p">)</span><span class="o">.</span><span class="fm">__init__</span><span class="p">()</span>
        <span class="bp">self</span><span class="o">.</span><span class="n">w_1</span> <span class="o">=</span> <span class="n">nn</span><span class="o">.</span><span class="n">Linear</span><span class="p">(</span><span class="n">d_model</span><span class="p">,</span> <span class="n">d_ff</span><span class="p">)</span>
        <span class="bp">self</span><span class="o">.</span><span class="n">w_2</span> <span class="o">=</span> <span class="n">nn</span><span class="o">.</span><span class="n">Linear</span><span class="p">(</span><span class="n">d_ff</span><span class="p">,</span> <span class="n">d_model</span><span class="p">)</span>
        <span class="bp">self</span><span class="o">.</span><span class="n">dropout</span> <span class="o">=</span> <span class="n">nn</span><span class="o">.</span><span class="n">Dropout</span><span class="p">(</span><span class="n">dropout</span><span class="p">)</span>

    <span class="k">def</span> <span class="nf">forward</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">x</span><span class="p">):</span>
        <span class="k">return</span> <span class="bp">self</span><span class="o">.</span><span class="n">w_2</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">dropout</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">w_1</span><span class="p">(</span><span class="n">x</span><span class="p">)</span><span class="o">.</span><span class="n">relu</span><span class="p">()))</span>
        <span class="c1"># In effect: linear layer -&gt; ReLU -&gt; dropout -&gt; linear layer</span>
</pre></div>
</div>
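<p>To make “position-wise” concrete, the sketch below repeats the class above and checks that running the network on a whole sequence gives the same result at one position as running it on that position alone. The tensor sizes and the seed are illustrative assumptions, not values from the original code.</p>

```python
import torch
import torch.nn as nn


class PositionwiseFeedForward(nn.Module):
    """Same structure as the class above: linear, ReLU, dropout, linear."""

    def __init__(self, d_model, d_ff, dropout=0.1):
        super().__init__()
        self.w_1 = nn.Linear(d_model, d_ff)
        self.w_2 = nn.Linear(d_ff, d_model)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x):
        return self.w_2(self.dropout(self.w_1(x).relu()))


torch.manual_seed(0)
ffn = PositionwiseFeedForward(d_model=8, d_ff=32).eval()  # eval() disables dropout

x = torch.randn(2, 5, 8)          # (batch, seq_len, d_model), illustrative sizes
full = ffn(x)                     # all positions at once
one_position = ffn(x[:, 3:4, :])  # position 3 only

print(full.shape)
print(torch.allclose(full[:, 3:4, :], one_position, atol=1e-5))
```

<p>Because the same weights are applied to each position separately, the network never mixes information across positions; that mixing is left to the attention sublayers.</p>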
</section>
<section id="decoder">
<h2>Decoder<a class="headerlink" href="#decoder" title="Permalink to this heading">#</a></h2>
<div align=center><img src="./figures/transformer_Decoder.png" alt="image-20230129183826948" style="zoom:50%;"></div>
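<p>The Decoder is built much like the Encoder: it is also a stack of N DecoderLayers. A DecoderLayer differs from an EncoderLayer in two ways:</p>
<p>1. An EncoderLayer consists of two sublayers: a multi-head self-attention layer and a fully connected network. A DecoderLayer adds a third sublayer, a multi-head attention layer over the Encoder output. As the figure shows, the bottom sublayer is a multi-head self-attention layer with the same structure as in the EncoderLayer, except that it is masked so that a position cannot attend to later positions. The middle sublayer is the added multi-head attention layer, which takes the Encoder output and the output of the self-attention layer below it as input. The fully connected network at the top is the same as in the EncoderLayer.</p>
<p>2. The input to an EncoderLayer comes only from the embedded source sequence (or the output of the previous EncoderLayer). The input to a DecoderLayer comes from the embedded target sequence and, in addition, from the Encoder output.</p>
<p>The DecoderLayer is implemented as follows:</p>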
<div class="highlight-python notranslate"><div class="highlight"><pre><span></span><span class="k">class</span> <span class="nc">DecoderLayer</span><span class="p">(</span><span class="n">nn</span><span class="o">.</span><span class="n">Module</span><span class="p">):</span>
    <span class="s2">&quot;Decoder is made of self-attn, src-attn, and feed forward (defined below)&quot;</span>

    <span class="k">def</span> <span class="fm">__init__</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">size</span><span class="p">,</span> <span class="n">self_attn</span><span class="p">,</span> <span class="n">src_attn</span><span class="p">,</span> <span class="n">feed_forward</span><span class="p">,</span> <span class="n">dropout</span><span class="p">):</span>
        <span class="nb">super</span><span class="p">(</span><span class="n">DecoderLayer</span><span class="p">,</span> <span class="bp">self</span><span class="p">)</span><span class="o">.</span><span class="fm">__init__</span><span class="p">()</span>
        <span class="bp">self</span><span class="o">.</span><span class="n">size</span> <span class="o">=</span> <span class="n">size</span>
        <span class="bp">self</span><span class="o">.</span><span class="n">self_attn</span> <span class="o">=</span> <span class="n">self_attn</span>
        <span class="bp">self</span><span class="o">.</span><span class="n">src_attn</span> <span class="o">=</span> <span class="n">src_attn</span>
        <span class="c1"># Source attention: the second attention block of the Decoder in the figure; it attends from the previous sublayer's output to the encoder output</span>
        <span class="bp">self</span><span class="o">.</span><span class="n">feed_forward</span> <span class="o">=</span> <span class="n">feed_forward</span>
        <span class="bp">self</span><span class="o">.</span><span class="n">sublayer</span> <span class="o">=</span> <span class="n">clones</span><span class="p">(</span><span class="n">SublayerConnection</span><span class="p">(</span><span class="n">size</span><span class="p">,</span> <span class="n">dropout</span><span class="p">),</span> <span class="mi">3</span><span class="p">)</span>
        <span class="c1"># The decoder layer has three sublayers, so it needs three residual connections</span>

    <span class="k">def</span> <span class="nf">forward</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">x</span><span class="p">,</span> <span class="n">memory</span><span class="p">,</span> <span class="n">src_mask</span><span class="p">,</span> <span class="n">tgt_mask</span><span class="p">):</span>
        <span class="s2">&quot;Follow Figure 1 (right) for connections.&quot;</span>
        <span class="n">m</span> <span class="o">=</span> <span class="n">memory</span>
        <span class="n">x</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">sublayer</span><span class="p">[</span><span class="mi">0</span><span class="p">](</span><span class="n">x</span><span class="p">,</span> <span class="k">lambda</span> <span class="n">x</span><span class="p">:</span> <span class="bp">self</span><span class="o">.</span><span class="n">self_attn</span><span class="p">(</span><span class="n">x</span><span class="p">,</span> <span class="n">x</span><span class="p">,</span> <span class="n">x</span><span class="p">,</span> <span class="n">tgt_mask</span><span class="p">))</span>
        <span class="n">x</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">sublayer</span><span class="p">[</span><span class="mi">1</span><span class="p">](</span><span class="n">x</span><span class="p">,</span> <span class="k">lambda</span> <span class="n">x</span><span class="p">:</span> <span class="bp">self</span><span class="o">.</span><span class="n">src_attn</span><span class="p">(</span><span class="n">x</span><span class="p">,</span> <span class="n">m</span><span class="p">,</span> <span class="n">m</span><span class="p">,</span> <span class="n">src_mask</span><span class="p">))</span>
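        <span class="c1"># This shows how the two attention sublayers of the decoder differ</span>
        <span class="c1"># The first is self-attention, as in the encoder: query, key, and value all come from the input x</span>
        <span class="c1"># The second attends over the encoder output: the query comes from x, while key and value come from the memory m (the encoder output)</span>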
        <span class="k">return</span> <span class="bp">self</span><span class="o">.</span><span class="n">sublayer</span><span class="p">[</span><span class="mi">2</span><span class="p">](</span><span class="n">x</span><span class="p">,</span> <span class="bp">self</span><span class="o">.</span><span class="n">feed_forward</span><span class="p">)</span>
</pre></div>
</div>
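<p>Note that, to compute attention between the target sequence and the source sequence, the second attention sublayer receives memory, that is, the Encoder output, in the forward pass. The overall Decoder is implemented much like the Encoder, so it is not repeated here.</p>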
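<p>The following sketch illustrates the query/key/value roles in the second attention sublayer. It does not use the tutorial's MultiHeadedAttention; as an assumption of this sketch, PyTorch's built-in nn.MultiheadAttention stands in for it, and all sizes are illustrative.</p>

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
d_model, h = 16, 4
# Stand-in for the tutorial's MultiHeadedAttention (an assumption of this sketch)
src_attn = nn.MultiheadAttention(embed_dim=d_model, num_heads=h, batch_first=True)

memory = torch.randn(1, 7, d_model)  # encoder output: 7 source positions
x = torch.randn(1, 3, d_model)       # decoder state: 3 target positions

# The query comes from the decoder; key and value come from the encoder output
out, weights = src_attn(query=x, key=memory, value=memory)

print(out.shape)      # one output vector per target position
print(weights.shape)  # one weight per (target position, source position) pair
```

<p>The source and target lengths may differ: the output keeps the target length, while each target position distributes its attention over all source positions.</p>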
</section>
<section id="id4">
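<h2>Residual Connections<a class="headerlink" href="#id4" title="Permalink to this heading">#</a></h2>
<p>To avoid model degradation, the Transformer connects every sublayer with a residual connection. With a residual connection, the input to the next layer is not only the output of the previous layer but also that layer's input. This lets information from the lowest layer pass directly to the highest layer, so that each layer can focus on learning the residual.</p>
<p>For example, in the first sublayer of the Encoder, the input goes into the multi-head self-attention layer and is also passed directly to that layer's output; the output is then added to the original input and normalized. The second sublayer works the same way. That is:
<span class="math notranslate nohighlight">\(
\rm output = LayerNorm(x + Sublayer(x))
\)</span>
LayerNorm is the layer normalization, and Sublayer is the operation of the sublayer (multi-head attention or the fully connected network). The source code implements the residual connection with a SublayerConnection class. As its docstring notes, for code simplicity it applies the norm before the sublayer rather than after the addition, so it actually computes x + Dropout(Sublayer(LayerNorm(x))):</p>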
<div class="highlight-python notranslate"><div class="highlight"><pre><span></span><span class="k">class</span> <span class="nc">SublayerConnection</span><span class="p">(</span><span class="n">nn</span><span class="o">.</span><span class="n">Module</span><span class="p">):</span>
    <span class="sd">&quot;&quot;&quot;</span>
<span class="sd">    A residual connection followed by a layer norm.</span>
<span class="sd">    Note for code simplicity the norm is first as opposed to last.</span>
<span class="sd">    &quot;&quot;&quot;</span>

    <span class="k">def</span> <span class="fm">__init__</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">size</span><span class="p">,</span> <span class="n">dropout</span><span class="p">):</span>
        <span class="nb">super</span><span class="p">(</span><span class="n">SublayerConnection</span><span class="p">,</span> <span class="bp">self</span><span class="p">)</span><span class="o">.</span><span class="fm">__init__</span><span class="p">()</span>
        <span class="bp">self</span><span class="o">.</span><span class="n">norm</span> <span class="o">=</span> <span class="n">LayerNorm</span><span class="p">(</span><span class="n">size</span><span class="p">)</span>
        <span class="bp">self</span><span class="o">.</span><span class="n">dropout</span> <span class="o">=</span> <span class="n">nn</span><span class="o">.</span><span class="n">Dropout</span><span class="p">(</span><span class="n">dropout</span><span class="p">)</span>
        <span class="c1"># Dropout layer</span>

    <span class="k">def</span> <span class="nf">forward</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">x</span><span class="p">,</span> <span class="n">sublayer</span><span class="p">):</span>
        <span class="s2">&quot;Apply residual connection to any sublayer with the same size.&quot;</span>
        <span class="k">return</span> <span class="n">x</span> <span class="o">+</span> <span class="bp">self</span><span class="o">.</span><span class="n">dropout</span><span class="p">(</span><span class="n">sublayer</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">norm</span><span class="p">(</span><span class="n">x</span><span class="p">)))</span>
        <span class="c1"># Residual connection: this module performs both the norm and the add</span>
</pre></div>
</div>
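<p>The difference between the formula and the code's ordering can be seen in a small sketch. Here nn.LayerNorm stands in for the tutorial's LayerNorm class and a single nn.Linear stands in for the sublayer; both are assumptions made for illustration, and dropout is omitted.</p>

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
size = 8
norm = nn.LayerNorm(size)         # stand-in for the tutorial's LayerNorm class
sublayer = nn.Linear(size, size)  # any sublayer that preserves the last dimension

x = torch.randn(2, 5, size)

post_norm = norm(x + sublayer(x))  # the formula: LayerNorm(x + Sublayer(x))
pre_norm = x + sublayer(norm(x))   # the code: x + Sublayer(LayerNorm(x)), dropout omitted

print(post_norm.shape, pre_norm.shape)
print(torch.allclose(post_norm, pre_norm))
```

<p>Both orderings keep the input shape, which is what allows sublayers to be stacked, but they generally produce different values. In the code's ordering the unnormalized x is carried through the addition unchanged.</p>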
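<p>Returning to the definition of EncoderLayer, we can see that the author creates two SublayerConnection objects in it, which provide the residual connections around the self-attention layer and around the fully connected layer:</p>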
<div class="highlight-python notranslate"><div class="highlight"><pre><span></span><span class="k">class</span> <span class="nc">EncoderLayer</span><span class="p">(</span><span class="n">nn</span><span class="o">.</span><span class="n">Module</span><span class="p">):</span>
    <span class="s2">&quot;Encoder is made up of self-attn and feed forward (defined below)&quot;</span>

    <span class="k">def</span> <span class="fm">__init__</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">size</span><span class="p">,</span> <span class="n">self_attn</span><span class="p">,</span> <span class="n">feed_forward</span><span class="p">,</span> <span class="n">dropout</span><span class="p">):</span>
        <span class="nb">super</span><span class="p">(</span><span class="n">EncoderLayer</span><span class="p">,</span> <span class="bp">self</span><span class="p">)</span><span class="o">.</span><span class="fm">__init__</span><span class="p">()</span>
        <span class="bp">self</span><span class="o">.</span><span class="n">self_attn</span> <span class="o">=</span> <span class="n">self_attn</span>
        <span class="bp">self</span><span class="o">.</span><span class="n">feed_forward</span> <span class="o">=</span> <span class="n">feed_forward</span>
        <span class="bp">self</span><span class="o">.</span><span class="n">sublayer</span> <span class="o">=</span> <span class="n">clones</span><span class="p">(</span><span class="n">SublayerConnection</span><span class="p">(</span><span class="n">size</span><span class="p">,</span> <span class="n">dropout</span><span class="p">),</span> <span class="mi">2</span><span class="p">)</span>
        <span class="c1"># Residual connection layers</span>
        <span class="c1"># One wraps the self-attention sublayer, the other wraps the feed-forward sublayer</span>
        <span class="bp">self</span><span class="o">.</span><span class="n">size</span> <span class="o">=</span> <span class="n">size</span>

    <span class="k">def</span> <span class="nf">forward</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">x</span><span class="p">,</span> <span class="n">mask</span><span class="p">):</span>
        <span class="s2">&quot;Follow Figure 1 (left) for connections.&quot;</span>
        <span class="n">x</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">sublayer</span><span class="p">[</span><span class="mi">0</span><span class="p">](</span><span class="n">x</span><span class="p">,</span> <span class="k">lambda</span> <span class="n">x</span><span class="p">:</span> <span class="bp">self</span><span class="o">.</span><span class="n">self_attn</span><span class="p">(</span><span class="n">x</span><span class="p">,</span> <span class="n">x</span><span class="p">,</span> <span class="n">x</span><span class="p">,</span> <span class="n">mask</span><span class="p">))</span>
        <span class="c1"># Pass self-attention in as the sublayer; the result is the self-attention output with the residual connection applied</span>
        <span class="k">return</span> <span class="bp">self</span><span class="o">.</span><span class="n">sublayer</span><span class="p">[</span><span class="mi">1</span><span class="p">](</span><span class="n">x</span><span class="p">,</span> <span class="bp">self</span><span class="o">.</span><span class="n">feed_forward</span><span class="p">)</span>
</pre></div>
</div>
<p>DecoderLayer uses the same class; because it has three sublayers, it creates three SublayerConnection objects.</p>
</section>
<section id="mask">
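<h2>Mask<a class="headerlink" href="#mask" title="Permalink to this heading">#</a></h2>
<p>The Transformer is an autoregressive model: like a language model, it predicts the output step by step from the history. For example, suppose the sentence pair is: 1. &lt;BOS&gt; 我爱你 &lt;EOS&gt;; 2. &lt;BOS&gt; I like you &lt;EOS&gt;. The Encoder receives sentence 1 as a whole and outputs its encoding. The Decoder, however, does not start with all of sentence 2 as input. It first receives the start token &lt;BOS&gt; and predicts I from &lt;BOS&gt; and the Encoder output; it then receives &lt;BOS&gt; I and predicts like from that input and the Encoder output. During training the whole target sentence is fed in at once, so an autoregressive model needs to mask its input to ensure that it never uses future information to predict the current token.</p>
<p>For more on autoregressive and autoencoding models, interested readers can consult further material; a few links are provided for reference. Blog posts: <a class="reference external" href="https://zhuanlan.zhihu.com/p/163455527">自回归语言模型 VS 自编码语言模型</a> (autoregressive vs. autoencoding language models) and <a class="reference external" href="https://www.cnblogs.com/sandwichnlp/p/11947627.html#%E9%A2%84%E8%AE%AD%E7%BB%83%E4%BB%BB%E5%8A%A1%E7%AE%80%E4%BB%8B">预训练语言模型整理</a> (a roundup of pre-trained language models); paper: <a class="reference external" href="http://jcip.cipsc.org.cn/CN/abstract/abstract3187.shtml">基于语言模型的预训练技术研究综述</a> (a survey of language-model-based pre-training techniques).</p>
<p>So in the Transformer we need a mask function that masks the input according to the step being predicted. The model cannot see masked information, which ensures that it predicts from history only. The mask is implemented as follows:</p>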
<div class="highlight-python notranslate"><div class="highlight"><pre><span></span><span class="k">def</span> <span class="nf">subsequent_mask</span><span class="p">(</span><span class="n">size</span><span class="p">):</span>
    <span class="s2">&quot;Mask out subsequent positions.&quot;</span>
    <span class="n">attn_shape</span> <span class="o">=</span> <span class="p">(</span><span class="mi">1</span><span class="p">,</span> <span class="n">size</span><span class="p">,</span> <span class="n">size</span><span class="p">)</span>
    <span class="c1"># Shape of the attention mask</span>
    <span class="n">subsequent_mask</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">triu</span><span class="p">(</span><span class="n">torch</span><span class="o">.</span><span class="n">ones</span><span class="p">(</span><span class="n">attn_shape</span><span class="p">),</span> <span class="n">diagonal</span><span class="o">=</span><span class="mi">1</span><span class="p">)</span><span class="o">.</span><span class="n">type</span><span class="p">(</span>
        <span class="n">torch</span><span class="o">.</span><span class="n">uint8</span>
    <span class="p">)</span>
    <span class="c1"># triu returns the upper-triangular part; diagonal=1 excludes the main diagonal</span>
    <span class="sd">&#39;&#39;&#39;</span>
<span class="sd">    An example:</span>
<span class="sd">    a = [[1,2,3],</span>
<span class="sd">         [4,5,6],</span>
<span class="sd">         [7,8,9]</span>
<span class="sd">    ]</span>
<span class="sd">    triu(a, diagonal=1) returns:</span>
<span class="sd">    [[0,2,3],</span>
<span class="sd">     [0,0,6],</span>
<span class="sd">     [0,0,0]]</span>
<span class="sd">    Building this upper-triangular matrix and keeping the positions that equal 0 gives the mask that hides future positions</span>
<span class="sd">    &#39;&#39;&#39;</span>
    <span class="k">return</span> <span class="n">subsequent_mask</span> <span class="o">==</span> <span class="mi">0</span>
</pre></div>
</div>
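<p>As a usage sketch, the function above can be applied to a small score matrix. The dummy all-zero scores and the use of masked_fill with negative infinity are assumptions chosen for illustration; they are not taken from the attention code of the original implementation.</p>

```python
import torch


def subsequent_mask(size):
    """Same function as above: True where attention is allowed."""
    attn_shape = (1, size, size)
    upper = torch.triu(torch.ones(attn_shape), diagonal=1).type(torch.uint8)
    return upper == 0


mask = subsequent_mask(4)
print(mask[0].int())  # lower-triangular pattern, main diagonal included

scores = torch.zeros(1, 4, 4)                       # dummy attention scores
scores = scores.masked_fill(~mask, float("-inf"))   # hide future positions
weights = scores.softmax(dim=-1)
print(weights[0])
```

<p>With equal scores, each row of weights is uniform over the current and earlier positions and exactly zero on later ones, so position t can only draw on positions up to t.</p>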
</section>
<section id="id5">
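<h2>Positional Encoding<a class="headerlink" href="#id5" title="Permalink to this heading">#</a></h2>
<p>The attention mechanism parallelizes well, but the way it computes attention also loses the order of the sequence. In an RNN or LSTM, the input is processed recurrently in sentence order, so the order of the input carries essential information, which matches the nature of natural language. As the analysis of attention above shows, however, for each token in the sequence all other positions are treated alike, so “我喜欢你” (I like you) and “你喜欢我” (you like me) look exactly the same to the attention mechanism. This is a serious limitation. To make use of sequence order and retain position information, the Transformer adds positional encoding, a mechanism that many later models adopted.</p>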
<p>Positional encoding encodes each token according to its position in the sequence and adds that encoding to the token's word embedding. There are many ways to do this; the Transformer uses sine and cosine functions:</p>
<div class="math notranslate nohighlight">
\[
\begin{aligned}
PE(pos, 2i) &amp;= \sin(pos/10000^{2i/d_{model}})\\
PE(pos, 2i+1) &amp;= \cos(pos/10000^{2i/d_{model}})
\end{aligned}
\]</div>
<p>In the formula above, pos is the position of the token in the sentence, and 2i and 2i+1 index the even and odd dimensions of the encoding vector: even dimensions use the sine function and odd dimensions use the cosine function. A simple example shows how the encoding is computed. Suppose the input is the sentence &quot;I like to code&quot;, of length 4, with the word-embedding matrix <span class="math notranslate nohighlight">\(\rm x\)</span> below. Each row is one word vector: <span class="math notranslate nohighlight">\(\rm x_0=[0.1,0.2,0.3,0.4]\)</span> is the vector for “I”, whose pos is 0; the second row is the vector for “like”, whose pos is 1; and so on:</p>
<div class="math notranslate nohighlight">
\[
\rm x = \begin{bmatrix} 0.1 &amp; 0.2 &amp; 0.3 &amp; 0.4 \\ 0.2 &amp; 0.3 &amp; 0.4 &amp; 0.5 \\ 0.3 &amp; 0.4 &amp; 0.5 &amp; 0.6 \\ 0.4 &amp; 0.5 &amp; 0.6 &amp; 0.7 \end{bmatrix}
\]</div>
<p>The word vectors after positional encoding are then:</p>
<div class="math notranslate nohighlight">
\[
\rm x_{PE} = \begin{bmatrix} 0.1 &amp; 0.2 &amp; 0.3 &amp; 0.4 \\ 0.2 &amp; 0.3 &amp; 0.4 &amp; 0.5 \\ 0.3 &amp; 0.4 &amp; 0.5 &amp; 0.6 \\ 0.4 &amp; 0.5 &amp; 0.6 &amp; 0.7 \end{bmatrix} + \begin{bmatrix} \sin(\frac{0}{10000^0}) &amp; \cos(\frac{0}{10000^0}) &amp; \sin(\frac{0}{10000^{2/4}}) &amp; \cos(\frac{0}{10000^{2/4}}) \\ \sin(\frac{1}{10000^0}) &amp; \cos(\frac{1}{10000^0}) &amp; \sin(\frac{1}{10000^{2/4}}) &amp; \cos(\frac{1}{10000^{2/4}}) \\ \sin(\frac{2}{10000^0}) &amp; \cos(\frac{2}{10000^0}) &amp; \sin(\frac{2}{10000^{2/4}}) &amp; \cos(\frac{2}{10000^{2/4}}) \\ \sin(\frac{3}{10000^0}) &amp; \cos(\frac{3}{10000^0}) &amp; \sin(\frac{3}{10000^{2/4}}) &amp; \cos(\frac{3}{10000^{2/4}}) \end{bmatrix} \approx \begin{bmatrix} 0.100 &amp; 1.200 &amp; 0.300 &amp; 1.400 \\ 1.041 &amp; 0.840 &amp; 0.410 &amp; 1.500 \\ 1.209 &amp; -0.016 &amp; 0.520 &amp; 1.600 \\ 0.541 &amp; -0.490 &amp; 0.630 &amp; 1.700 \end{bmatrix}
\]</div>
<p>The following code computes such a positional-encoding matrix. Note that the call below passes n=100 instead of 10000, so the third and fourth columns of its output differ from the sine and cosine terms in the example above:</p>
<div class="highlight-python notranslate"><div class="highlight"><pre><span></span><span class="kn">import</span> <span class="nn">numpy</span> <span class="k">as</span> <span class="nn">np</span>
<span class="kn">import</span> <span class="nn">matplotlib.pyplot</span> <span class="k">as</span> <span class="nn">plt</span>
<span class="k">def</span> <span class="nf">PositionEncoding</span><span class="p">(</span><span class="n">seq_len</span><span class="p">,</span> <span class="n">d_model</span><span class="p">,</span> <span class="n">n</span><span class="o">=</span><span class="mi">10000</span><span class="p">):</span>
    <span class="n">P</span> <span class="o">=</span> <span class="n">np</span><span class="o">.</span><span class="n">zeros</span><span class="p">((</span><span class="n">seq_len</span><span class="p">,</span> <span class="n">d_model</span><span class="p">))</span>
    <span class="k">for</span> <span class="n">k</span> <span class="ow">in</span> <span class="nb">range</span><span class="p">(</span><span class="n">seq_len</span><span class="p">):</span>
        <span class="k">for</span> <span class="n">i</span> <span class="ow">in</span> <span class="n">np</span><span class="o">.</span><span class="n">arange</span><span class="p">(</span><span class="nb">int</span><span class="p">(</span><span class="n">d_model</span><span class="o">/</span><span class="mi">2</span><span class="p">)):</span>
            <span class="n">denominator</span> <span class="o">=</span> <span class="n">np</span><span class="o">.</span><span class="n">power</span><span class="p">(</span><span class="n">n</span><span class="p">,</span> <span class="mi">2</span><span class="o">*</span><span class="n">i</span><span class="o">/</span><span class="n">d_model</span><span class="p">)</span>
            <span class="n">P</span><span class="p">[</span><span class="n">k</span><span class="p">,</span> <span class="mi">2</span><span class="o">*</span><span class="n">i</span><span class="p">]</span> <span class="o">=</span> <span class="n">np</span><span class="o">.</span><span class="n">sin</span><span class="p">(</span><span class="n">k</span><span class="o">/</span><span class="n">denominator</span><span class="p">)</span>
            <span class="n">P</span><span class="p">[</span><span class="n">k</span><span class="p">,</span> <span class="mi">2</span><span class="o">*</span><span class="n">i</span><span class="o">+</span><span class="mi">1</span><span class="p">]</span> <span class="o">=</span> <span class="n">np</span><span class="o">.</span><span class="n">cos</span><span class="p">(</span><span class="n">k</span><span class="o">/</span><span class="n">denominator</span><span class="p">)</span>
    <span class="k">return</span> <span class="n">P</span>

<span class="n">P</span> <span class="o">=</span> <span class="n">PositionEncoding</span><span class="p">(</span><span class="n">seq_len</span><span class="o">=</span><span class="mi">4</span><span class="p">,</span> <span class="n">d_model</span><span class="o">=</span><span class="mi">4</span><span class="p">,</span> <span class="n">n</span><span class="o">=</span><span class="mi">100</span><span class="p">)</span>
<span class="nb">print</span><span class="p">(</span><span class="n">P</span><span class="p">)</span>
</pre></div>
</div>
<div class="highlight-python notranslate"><div class="highlight"><pre><span></span><span class="p">[[</span> <span class="mf">0.</span>          <span class="mf">1.</span>          <span class="mf">0.</span>          <span class="mf">1.</span>        <span class="p">]</span>
 <span class="p">[</span> <span class="mf">0.84147098</span>  <span class="mf">0.54030231</span>  <span class="mf">0.09983342</span>  <span class="mf">0.99500417</span><span class="p">]</span>
 <span class="p">[</span> <span class="mf">0.90929743</span> <span class="o">-</span><span class="mf">0.41614684</span>  <span class="mf">0.19866933</span>  <span class="mf">0.98006658</span><span class="p">]</span>
 <span class="p">[</span> <span class="mf">0.14112001</span> <span class="o">-</span><span class="mf">0.9899925</span>   <span class="mf">0.29552021</span>  <span class="mf">0.95533649</span><span class="p">]]</span>
</pre></div>
</div>
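<p>For comparison, the sketch below evaluates x + PE for the example sentence with the base 10000 that appears in the formula. The vectorized function position_encoding is a hypothetical helper written for this illustration; it is intended to compute the same quantity as the loop version above.</p>

```python
import numpy as np


def position_encoding(seq_len, d_model, n=10000):
    """Vectorized form of the sinusoidal formula (hypothetical helper)."""
    pos = np.arange(seq_len)[:, None]          # (seq_len, 1)
    two_i = np.arange(0, d_model, 2)[None, :]  # (1, d_model / 2), the values 2i
    angle = pos / np.power(n, two_i / d_model)
    P = np.zeros((seq_len, d_model))
    P[:, 0::2] = np.sin(angle)  # even dimensions
    P[:, 1::2] = np.cos(angle)  # odd dimensions
    return P


x = np.array([[0.1, 0.2, 0.3, 0.4],
              [0.2, 0.3, 0.4, 0.5],
              [0.3, 0.4, 0.5, 0.6],
              [0.4, 0.5, 0.6, 0.7]])

x_pe = x + position_encoding(seq_len=4, d_model=4, n=10000)
print(np.round(x_pe, 3))
```

<p>With d_model = 4 the second frequency is 1/10000^(2/4) = 1/100, so the third and fourth columns change only slowly from row to row.</p>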
<p>This positional encoding has two main advantages:</p>
<ol class="simple">
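<li><p>It lets PE handle sentences longer than any in the training set. If the longest training sentence has 20 words and a sentence of length 21 arrives, the formula can still compute the encoding for the 21st position.</p></li>
<li><p>It lets the model compute relative positions easily: for a fixed offset k, PE(pos+k) can be written as a linear function of PE(pos), because sin(A+B) = sin(A)cos(B) + cos(A)sin(B) and cos(A+B) = cos(A)cos(B) - sin(A)sin(B).</p></li>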
</ol>
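<p>The second property can be checked numerically. The sketch below builds a block-diagonal matrix M from the angle-addition identities and verifies that it maps PE(pos) to PE(pos+k). The helper pe, the matrix M, and the values of pos and k are illustrative assumptions of this sketch.</p>

```python
import numpy as np

d_model, n = 4, 10000
pos, k = 5, 3  # illustrative position and offset

freqs = 1.0 / np.power(n, np.arange(0, d_model, 2) / d_model)


def pe(p):
    """Sinusoidal encoding of one position, laid out as (sin, cos) pairs."""
    out = np.zeros(d_model)
    out[0::2] = np.sin(p * freqs)
    out[1::2] = np.cos(p * freqs)
    return out


# Block-diagonal matrix that depends on the offset k only, not on pos
M = np.zeros((d_model, d_model))
for j, w in enumerate(freqs):
    c, s = np.cos(k * w), np.sin(k * w)
    M[2 * j:2 * j + 2, 2 * j:2 * j + 2] = [[c, s], [-s, c]]

print(np.allclose(M @ pe(pos), pe(pos + k)))
```

<p>Each 2x2 block applies the two identities to one (sin, cos) pair, so the same M works for every pos.</p>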
<p>Many authors have analyzed mathematically why this encoding is preferable to simpler, more intuitive alternatives. Since this article focuses on the code, the details are omitted here; interested readers can consult related material such as the blog posts <a class="reference external" href="https://kazemnejad.com/blog/transformer_architecture_positional_encoding/">Transformer Architecture: The Positional Encoding</a> and <a class="reference external" href="https://machinelearningmastery.com/a-gentle-introduction-to-positional-encoding-in-transformer-models-part-1/">A Gentle Introduction to Positional Encoding in Transformer Models</a>.</p>
<p>An example of the encoding result:</p>
<div align=center><img src="./figures/transformer_position_embedding.png" alt="image-20230129201913077" style="zoom:50%;"/></div>
<p>Positional encoding is implemented as follows:</p>
<div class="highlight-python notranslate"><div class="highlight"><pre><span></span><span class="k">class</span> <span class="nc">PositionalEncoding</span><span class="p">(</span><span class="n">nn</span><span class="o">.</span><span class="n">Module</span><span class="p">):</span>

    <span class="k">def</span> <span class="fm">__init__</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">d_model</span><span class="p">,</span> <span class="n">dropout</span><span class="p">,</span> <span class="n">max_len</span><span class="o">=</span><span class="mi">5000</span><span class="p">):</span>
        <span class="nb">super</span><span class="p">(</span><span class="n">PositionalEncoding</span><span class="p">,</span> <span class="bp">self</span><span class="p">)</span><span class="o">.</span><span class="fm">__init__</span><span class="p">()</span>
        <span class="bp">self</span><span class="o">.</span><span class="n">dropout</span> <span class="o">=</span> <span class="n">nn</span><span class="o">.</span><span class="n">Dropout</span><span class="p">(</span><span class="n">p</span><span class="o">=</span><span class="n">dropout</span><span class="p">)</span>

        <span class="c1"># Compute the positional encodings once in log space.</span>
        <span class="n">pe</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">zeros</span><span class="p">(</span><span class="n">max_len</span><span class="p">,</span> <span class="n">d_model</span><span class="p">)</span>
        <span class="c1"># Position-encoding matrix, initialized to zeros</span>
        <span class="n">position</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">arange</span><span class="p">(</span><span class="mi">0</span><span class="p">,</span> <span class="n">max_len</span><span class="p">)</span><span class="o">.</span><span class="n">unsqueeze</span><span class="p">(</span><span class="mi">1</span><span class="p">)</span>
        <span class="c1"># pos in the formula</span>
        <span class="n">div_term</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">exp</span><span class="p">(</span>
            <span class="n">torch</span><span class="o">.</span><span class="n">arange</span><span class="p">(</span><span class="mi">0</span><span class="p">,</span> <span class="n">d_model</span><span class="p">,</span> <span class="mi">2</span><span class="p">)</span> <span class="o">*</span> <span class="o">-</span><span class="p">(</span><span class="n">math</span><span class="o">.</span><span class="n">log</span><span class="p">(</span><span class="mf">10000.0</span><span class="p">)</span> <span class="o">/</span> <span class="n">d_model</span><span class="p">)</span>
        <span class="p">)</span>
        <span class="c1"># Computes 1 / 10000^(2i/d_model) in log space, as exp(-2i * ln(10000) / d_model)</span>
        <span class="n">pe</span><span class="p">[:,</span> <span class="mi">0</span><span class="p">::</span><span class="mi">2</span><span class="p">]</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">sin</span><span class="p">(</span><span class="n">position</span> <span class="o">*</span> <span class="n">div_term</span><span class="p">)</span>
        <span class="n">pe</span><span class="p">[:,</span> <span class="mi">1</span><span class="p">::</span><span class="mi">2</span><span class="p">]</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">cos</span><span class="p">(</span><span class="n">position</span> <span class="o">*</span> <span class="n">div_term</span><span class="p">)</span>
        <span class="c1"># Compute the positional encodings from the formula</span>
        <span class="n">pe</span> <span class="o">=</span> <span class="n">pe</span><span class="o">.</span><span class="n">unsqueeze</span><span class="p">(</span><span class="mi">0</span><span class="p">)</span>
        <span class="bp">self</span><span class="o">.</span><span class="n">register_buffer</span><span class="p">(</span><span class="s2">&quot;pe&quot;</span><span class="p">,</span> <span class="n">pe</span><span class="p">)</span>
        <span class="c1"># Register pe as a buffer: it is stored with the module but not updated during training</span>

    <span class="k">def</span> <span class="nf">forward</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">x</span><span class="p">):</span>
        <span class="n">x</span> <span class="o">=</span> <span class="n">x</span> <span class="o">+</span> <span class="bp">self</span><span class="o">.</span><span class="n">pe</span><span class="p">[:,</span> <span class="p">:</span> <span class="n">x</span><span class="o">.</span><span class="n">size</span><span class="p">(</span><span class="mi">1</span><span class="p">)]</span><span class="o">.</span><span class="n">requires_grad_</span><span class="p">(</span><span class="kc">False</span><span class="p">)</span>
        <span class="c1"># Add the positional encoding to the word embeddings</span>
        <span class="k">return</span> <span class="bp">self</span><span class="o">.</span><span class="n">dropout</span><span class="p">(</span><span class="n">x</span><span class="p">)</span>
</pre></div>
</div>
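<p>The log-space expression for div_term can look unrelated to the formula at first glance. The short check below, with an illustrative d_model of 8, compares it with a direct evaluation of 1 / 10000^(2i/d_model); the variable names two_i and direct are introduced only for this sketch.</p>

```python
import math
import torch

d_model = 8  # illustrative size
two_i = torch.arange(0, d_model, 2)  # the values 2i: 0, 2, 4, 6

# As written in PositionalEncoding.__init__
div_term = torch.exp(two_i * -(math.log(10000.0) / d_model))

# Direct reading of the formula: 1 / 10000^(2i / d_model)
direct = 1.0 / torch.pow(torch.tensor(10000.0), two_i / d_model)

print(div_term)
print(torch.allclose(div_term, direct))
```

<p>Multiplying position by div_term is therefore the same as dividing pos by 10000^(2i/d_model), which is the argument of the sine and cosine in the formula.</p>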
</section>
<section id="id6">
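<h2>Building the Full Model<a class="headerlink" href="#id6" title="Permalink to this heading">#</a></h2>
<p>With the pieces above implemented, all that remains is to combine them. Here a single function builds the complete model:</p>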
<div class="highlight-python notranslate"><div class="highlight"><pre><span></span><span class="k">def</span> <span class="nf">make_model</span><span class="p">(</span>
    <span class="n">src_vocab</span><span class="p">,</span> <span class="n">tgt_vocab</span><span class="p">,</span> <span class="n">N</span><span class="o">=</span><span class="mi">6</span><span class="p">,</span> <span class="n">d_model</span><span class="o">=</span><span class="mi">512</span><span class="p">,</span> <span class="n">d_ff</span><span class="o">=</span><span class="mi">2048</span><span class="p">,</span> <span class="n">h</span><span class="o">=</span><span class="mi">8</span><span class="p">,</span> <span class="n">dropout</span><span class="o">=</span><span class="mf">0.1</span>
<span class="p">):</span>
    <span class="s2">&quot;Helper: Construct a model from hyperparameters.&quot;</span>
    <span class="n">c</span> <span class="o">=</span> <span class="n">copy</span><span class="o">.</span><span class="n">deepcopy</span>
    <span class="n">attn</span> <span class="o">=</span> <span class="n">MultiHeadedAttention</span><span class="p">(</span><span class="n">h</span><span class="p">,</span> <span class="n">d_model</span><span class="p">)</span>
    <span class="c1"># Multi-head attention layer</span>
    <span class="n">ff</span> <span class="o">=</span> <span class="n">PositionwiseFeedForward</span><span class="p">(</span><span class="n">d_model</span><span class="p">,</span> <span class="n">d_ff</span><span class="p">,</span> <span class="n">dropout</span><span class="p">)</span>
    <span class="c1"># Feed-forward network layer</span>
    <span class="n">position</span> <span class="o">=</span> <span class="n">PositionalEncoding</span><span class="p">(</span><span class="n">d_model</span><span class="p">,</span> <span class="n">dropout</span><span class="p">)</span>
    <span class="c1"># 位置编码层</span>
    <span class="n">model</span> <span class="o">=</span> <span class="n">EncoderDecoder</span><span class="p">(</span>
        <span class="c1"># Encoder: N stacked encoder layers</span>
        <span class="n">Encoder</span><span class="p">(</span><span class="n">EncoderLayer</span><span class="p">(</span><span class="n">d_model</span><span class="p">,</span> <span class="n">c</span><span class="p">(</span><span class="n">attn</span><span class="p">),</span> <span class="n">c</span><span class="p">(</span><span class="n">ff</span><span class="p">),</span> <span class="n">dropout</span><span class="p">),</span> <span class="n">N</span><span class="p">),</span>
        <span class="c1"># Decoder: the first attn is the self-attention layer, the second attends over the encoder output</span>
        <span class="n">Decoder</span><span class="p">(</span><span class="n">DecoderLayer</span><span class="p">(</span><span class="n">d_model</span><span class="p">,</span> <span class="n">c</span><span class="p">(</span><span class="n">attn</span><span class="p">),</span> <span class="n">c</span><span class="p">(</span><span class="n">attn</span><span class="p">),</span> <span class="n">c</span><span class="p">(</span><span class="n">ff</span><span class="p">),</span> <span class="n">dropout</span><span class="p">),</span> <span class="n">N</span><span class="p">),</span>
        <span class="c1"># Source-side embedding followed by positional encoding</span>
        <span class="n">nn</span><span class="o">.</span><span class="n">Sequential</span><span class="p">(</span><span class="n">Embeddings</span><span class="p">(</span><span class="n">d_model</span><span class="p">,</span> <span class="n">src_vocab</span><span class="p">),</span> <span class="n">c</span><span class="p">(</span><span class="n">position</span><span class="p">)),</span>
        <span class="c1"># Target-side embedding followed by positional encoding</span>
        <span class="n">nn</span><span class="o">.</span><span class="n">Sequential</span><span class="p">(</span><span class="n">Embeddings</span><span class="p">(</span><span class="n">d_model</span><span class="p">,</span> <span class="n">tgt_vocab</span><span class="p">),</span> <span class="n">c</span><span class="p">(</span><span class="n">position</span><span class="p">)),</span>
        <span class="c1"># Output (classification) layer over the target vocabulary</span>
        <span class="n">Generator</span><span class="p">(</span><span class="n">d_model</span><span class="p">,</span> <span class="n">tgt_vocab</span><span class="p">),</span>
    <span class="p">)</span>

    <span class="c1"># This was important from their code.</span>
    <span class="c1"># Initialize parameters with Glorot / fan_avg (Xavier uniform).</span>
    <span class="c1"># Only tensors with more than one dimension (weight matrices) are re-initialized.</span>
    <span class="k">for</span> <span class="n">p</span> <span class="ow">in</span> <span class="n">model</span><span class="o">.</span><span class="n">parameters</span><span class="p">():</span>
        <span class="k">if</span> <span class="n">p</span><span class="o">.</span><span class="n">dim</span><span class="p">()</span> <span class="o">&gt;</span> <span class="mi">1</span><span class="p">:</span>
            <span class="n">nn</span><span class="o">.</span><span class="n">init</span><span class="o">.</span><span class="n">xavier_uniform_</span><span class="p">(</span><span class="n">p</span><span class="p">)</span>
    <span class="k">return</span> <span class="n">model</span>
</pre></div>
</div>
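<p>Two details of <code>make_model</code> are easy to miss: every sublayer is passed through <code>copy.deepcopy</code> so that layers do not share weights, and only parameters with more than one dimension are re-initialized. The following is a minimal sketch of both points. It uses a plain <code>nn.Linear</code> as a hypothetical stand-in for the attention and feed-forward modules, so it illustrates the mechanism rather than the full model.</p>

```python
import copy

import torch
import torch.nn as nn

torch.manual_seed(0)

# Hypothetical stand-in for one sublayer (attention or feed-forward module).
proto = nn.Linear(8, 8)

# deepcopy gives every layer its own parameter tensors; reusing `proto`
# directly would tie the weights of all layers together.
layers = nn.ModuleList([copy.deepcopy(proto) for _ in range(3)])

# Same rule as in make_model: re-initialize only tensors with more than one
# dimension (weight matrices); 1-D tensors such as biases are left untouched.
for p in layers.parameters():
    if p.dim() > 1:
        nn.init.xavier_uniform_(p)

# The copies own separate storage, so they can diverge during training.
independent = layers[0].weight.data_ptr() != layers[1].weight.data_ptr()

# Xavier-uniform samples from U(-a, a) with a = sqrt(6 / (fan_in + fan_out)).
bound = (6.0 / (8 + 8)) ** 0.5
max_abs = max(p.abs().max().item() for p in layers.parameters() if p.dim() > 1)

print(independent, max_abs <= bound)
```

<p>Because the initialization loop runs after the copies are made, each layer ends up with its own random weights even though all of them were cloned from the same prototype.</p>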
</section>
<section id="id7">
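<h2>Training<a class="headerlink" href="#id7" title="Permalink to this heading">#</a></h2>
<p>Training this PyTorch implementation of the Transformer follows the usual PyTorch workflow. The overall procedure is much the same as for any other deep learning model built with PyTorch, except that a few lower-level details are customized.</p>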
<section id="traning-loop">
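<h3>Training Loop<a class="headerlink" href="#traning-loop" title="Permalink to this heading">#</a></h3>
<p>PyTorch does not ship a ready-made training loop, so the author writes one by hand, starting with a small class that records training progress:</p>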
<div class="highlight-python notranslate"><div class="highlight"><pre><span></span><span class="k">class</span> <span class="nc">TrainState</span><span class="p">:</span>
    <span class="sd">&quot;&quot;&quot;Track number of steps, examples, and tokens processed&quot;&quot;&quot;</span>

    <span class="n">step</span><span class="p">:</span> <span class="nb">int</span> <span class="o">=</span> <span class="mi">0</span>  <span class="c1"># Steps in the current epoch</span>
    <span class="n">accum_step</span><span class="p">:</span> <span class="nb">int</span> <span class="o">=</span> <span class="mi">0</span>  <span class="c1"># Number of gradient accumulation steps</span>
    <span class="n">samples</span><span class="p">:</span> <span class="nb">int</span> <span class="o">=</span> <span class="mi">0</span>  <span class="c1"># total # of examples used</span>
    <span class="n">tokens</span><span class="p">:</span> <span class="nb">int</span> <span class="o">=</span> <span class="mi">0</span>  <span class="c1"># total # of tokens processed</span>
</pre></div>
</div>
<p>Next, a function that runs a single epoch is defined on top of this class:</p>
<div class="highlight-python notranslate"><div class="highlight"><pre><span></span><span class="k">def</span> <span class="nf">run_epoch</span><span class="p">(</span>
    <span class="n">data_iter</span><span class="p">,</span>  <span class="c1"># iterator over batches</span>
    <span class="n">model</span><span class="p">,</span>  <span class="c1"># model</span>
    <span class="n">loss_compute</span><span class="p">,</span>  <span class="c1"># loss computation function</span>
    <span class="n">optimizer</span><span class="p">,</span>  <span class="c1"># optimizer</span>
    <span class="n">scheduler</span><span class="p">,</span>  <span class="c1"># learning-rate scheduler</span>
    <span class="n">mode</span><span class="o">=</span><span class="s2">&quot;train&quot;</span><span class="p">,</span>  <span class="c1"># mode: training or evaluation</span>
    <span class="n">accum_iter</span><span class="o">=</span><span class="mi">1</span><span class="p">,</span>  <span class="c1"># batches to accumulate gradients over per optimizer step</span>
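    <span class="n">train_state</span><span class="o">=</span><span class="kc">None</span><span class="p">,</span>
<span class="p">):</span>
    <span class="sd">&quot;&quot;&quot;Train a single epoch&quot;&quot;&quot;</span>
    <span class="c1"># Create a fresh tracker per call instead of sharing one default instance</span>
    <span class="k">if</span> <span class="n">train_state</span> <span class="ow">is</span> <span class="kc">None</span><span class="p">:</span>
        <span class="n">train_state</span> <span class="o">=</span> <span class="n">TrainState</span><span class="p">()</span>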
    <span class="n">start</span> <span class="o">=</span> <span class="n">time</span><span class="o">.</span><span class="n">time</span><span class="p">()</span>
    <span class="n">total_tokens</span> <span class="o">=</span> <span class="mi">0</span>
    <span class="n">total_loss</span> <span class="o">=</span> <span class="mi">0</span>
    <span class="n">tokens</span> <span class="o">=</span> <span class="mi">0</span>
    <span class="n">n_accum</span> <span class="o">=</span> <span class="mi">0</span>
    <span class="k">for</span> <span class="n">i</span><span class="p">,</span> <span class="n">batch</span> <span class="ow">in</span> <span class="nb">enumerate</span><span class="p">(</span><span class="n">data_iter</span><span class="p">):</span>
        <span class="n">out</span> <span class="o">=</span> <span class="n">model</span><span class="o">.</span><span class="n">forward</span><span class="p">(</span>
            <span class="n">batch</span><span class="o">.</span><span class="n">src</span><span class="p">,</span> <span class="n">batch</span><span class="o">.</span><span class="n">tgt</span><span class="p">,</span> <span class="n">batch</span><span class="o">.</span><span class="n">src_mask</span><span class="p">,</span> <span class="n">batch</span><span class="o">.</span><span class="n">tgt_mask</span>
        <span class="p">)</span>  <span class="c1"># forward pass</span>
        <span class="c1"># Compute the loss for the current batch</span>
        <span class="n">loss</span><span class="p">,</span> <span class="n">loss_node</span> <span class="o">=</span> <span class="n">loss_compute</span><span class="p">(</span><span class="n">out</span><span class="p">,</span> <span class="n">batch</span><span class="o">.</span><span class="n">tgt_y</span><span class="p">,</span> <span class="n">batch</span><span class="o">.</span><span class="n">ntokens</span><span class="p">)</span>
        <span class="c1"># loss_node = loss_node / accum_iter</span>
        <span class="k">if</span> <span class="n">mode</span> <span class="o">==</span> <span class="s2">&quot;train&quot;</span> <span class="ow">or</span> <span class="n">mode</span> <span class="o">==</span> <span class="s2">&quot;train+log&quot;</span><span class="p">:</span>
            <span class="c1"># Backpropagation</span>
            <span class="n">loss_node</span><span class="o">.</span><span class="n">backward</span><span class="p">()</span>
            <span class="n">train_state</span><span class="o">.</span><span class="n">step</span> <span class="o">+=</span> <span class="mi">1</span>
            <span class="n">train_state</span><span class="o">.</span><span class="n">samples</span> <span class="o">+=</span> <span class="n">batch</span><span class="o">.</span><span class="n">src</span><span class="o">.</span><span class="n">shape</span><span class="p">[</span><span class="mi">0</span><span class="p">]</span>
            <span class="n">train_state</span><span class="o">.</span><span class="n">tokens</span> <span class="o">+=</span> <span class="n">batch</span><span class="o">.</span><span class="n">ntokens</span>
            <span class="k">if</span> <span class="n">i</span> <span class="o">%</span> <span class="n">accum_iter</span> <span class="o">==</span> <span class="mi">0</span><span class="p">:</span>
                <span class="c1"># Step the optimizer once every accum_iter batches</span>
                <span class="n">optimizer</span><span class="o">.</span><span class="n">step</span><span class="p">()</span>
                <span class="n">optimizer</span><span class="o">.</span><span class="n">zero_grad</span><span class="p">(</span><span class="n">set_to_none</span><span class="o">=</span><span class="kc">True</span><span class="p">)</span>
                <span class="n">n_accum</span> <span class="o">+=</span> <span class="mi">1</span>
                <span class="n">train_state</span><span class="o">.</span><span class="n">accum_step</span> <span class="o">+=</span> <span class="mi">1</span>
            <span class="n">scheduler</span><span class="o">.</span><span class="n">step</span><span class="p">()</span>

        <span class="n">total_loss</span> <span class="o">+=</span> <span class="n">loss</span>
        <span class="n">total_tokens</span> <span class="o">+=</span> <span class="n">batch</span><span class="o">.</span><span class="n">ntokens</span>
        <span class="n">tokens</span> <span class="o">+=</span> <span class="n">batch</span><span class="o">.</span><span class="n">ntokens</span>
        <span class="k">if</span> <span class="n">i</span> <span class="o">%</span> <span class="mi">40</span> <span class="o">==</span> <span class="mi">1</span> <span class="ow">and</span> <span class="p">(</span><span class="n">mode</span> <span class="o">==</span> <span class="s2">&quot;train&quot;</span> <span class="ow">or</span> <span class="n">mode</span> <span class="o">==</span> <span class="s2">&quot;train+log&quot;</span><span class="p">):</span>
            <span class="n">lr</span> <span class="o">=</span> <span class="n">optimizer</span><span class="o">.</span><span class="n">param_groups</span><span class="p">[</span><span class="mi">0</span><span class="p">][</span><span class="s2">&quot;lr&quot;</span><span class="p">]</span>
            <span class="n">elapsed</span> <span class="o">=</span> <span class="n">time</span><span class="o">.</span><span class="n">time</span><span class="p">()</span> <span class="o">-</span> <span class="n">start</span>
            <span class="nb">print</span><span class="p">(</span>
                <span class="p">(</span>
                    <span class="s2">&quot;Epoch Step: </span><span class="si">%6d</span><span class="s2"> | Accumulation Step: </span><span class="si">%3d</span><span class="s2"> | Loss: </span><span class="si">%6.2f</span><span class="s2"> &quot;</span>
                    <span class="o">+</span> <span class="s2">&quot;| Tokens / Sec: </span><span class="si">%7.1f</span><span class="s2"> | Learning Rate: </span><span class="si">%6.1e</span><span class="s2">&quot;</span>
                <span class="p">)</span>
                <span class="o">%</span> <span class="p">(</span><span class="n">i</span><span class="p">,</span> <span class="n">n_accum</span><span class="p">,</span> <span class="n">loss</span> <span class="o">/</span> <span class="n">batch</span><span class="o">.</span><span class="n">ntokens</span><span class="p">,</span> <span class="n">tokens</span> <span class="o">/</span> <span class="n">elapsed</span><span class="p">,</span> <span class="n">lr</span><span class="p">)</span>
            <span class="p">)</span>
            <span class="n">start</span> <span class="o">=</span> <span class="n">time</span><span class="o">.</span><span class="n">time</span><span class="p">()</span>
            <span class="n">tokens</span> <span class="o">=</span> <span class="mi">0</span>
        <span class="k">del</span> <span class="n">loss</span>
        <span class="k">del</span> <span class="n">loss_node</span>
    <span class="k">return</span> <span class="n">total_loss</span> <span class="o">/</span> <span class="n">total_tokens</span><span class="p">,</span> <span class="n">train_state</span>
</pre></div>
</div>
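<p>The <code>accum_iter</code> argument implements gradient accumulation: <code>backward()</code> adds into each parameter's <code>.grad</code>, so several small batches can contribute to one optimizer step. The sketch below uses a toy linear model and random data (both assumptions made purely for illustration) to show that accumulating over micro-batches reproduces the full-batch gradient once each loss is divided by <code>accum_iter</code>, which is the scaling that the commented-out line in <code>run_epoch</code> would apply.</p>

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

# Toy model and data, chosen only to keep the example small.
model = nn.Linear(4, 1)
loss_fn = nn.MSELoss()
x, y = torch.randn(8, 4), torch.randn(8, 1)

# Reference: a single backward pass over the full batch of 8 examples.
model.zero_grad()
loss_fn(model(x), y).backward()
full_grad = model.weight.grad.clone()

# Accumulation: 4 micro-batches of 2 examples each. Gradients are summed
# into .grad, so each loss is divided by accum_iter to recover the mean.
accum_iter = 4
model.zero_grad()
for xb, yb in zip(x.chunk(accum_iter), y.chunk(accum_iter)):
    (loss_fn(model(xb), yb) / accum_iter).backward()
accum_grad = model.weight.grad.clone()

print(torch.allclose(full_grad, accum_grad, atol=1e-6))
```

<p>Without the division the accumulated gradient would be <code>accum_iter</code> times larger, which in effect scales the learning rate by the same factor.</p>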
</section>
<section id="id8">
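<h3>Optimizer<a class="headerlink" href="#id8" title="Permalink to this heading">#</a></h3>
<p>The Transformer is trained with the Adam optimizer, and the learning rate is varied over the course of training according to:
<span class="math notranslate nohighlight">\(
\mathrm{learning\_rate} = d_{\mathrm{model}}^{-0.5} \cdot \min\left(\mathrm{step\_num}^{-0.5},\ \mathrm{step\_num} \cdot \mathrm{warmup\_steps}^{-1.5}\right)
\)</span>
That is, the learning rate grows linearly during the first warmup_steps steps and afterwards decays in proportion to the inverse square root of the step number. Why Adam was chosen and what its advantages are is not covered here; interested readers can consult material such as <a class="reference external" href="https://zhuanlan.zhihu.com/p/90169812">【Adam】优化算法浅析</a> for the underlying mathematics. Based on the formula above, the learning-rate schedule is implemented as follows (<code class="docutils literal notranslate"><span class="pre">factor</span></code> is an additional overall scaling constant that does not appear in the formula):</p>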
<div class="highlight-python notranslate"><div class="highlight"><pre><span></span><span class="k">def</span> <span class="nf">rate</span><span class="p">(</span><span class="n">step</span><span class="p">,</span> <span class="n">model_size</span><span class="p">,</span> <span class="n">factor</span><span class="p">,</span> <span class="n">warmup</span><span class="p">):</span>
    <span class="sd">&quot;&quot;&quot;</span>
<span class="sd">    we have to default the step to 1 for LambdaLR function</span>
<span class="sd">    to avoid zero raising to negative power.</span>
<span class="sd">    &quot;&quot;&quot;</span>
    <span class="k">if</span> <span class="n">step</span> <span class="o">==</span> <span class="mi">0</span><span class="p">:</span>
        <span class="n">step</span> <span class="o">=</span> <span class="mi">1</span>
    <span class="k">return</span> <span class="n">factor</span> <span class="o">*</span> <span class="p">(</span>
        <span class="n">model_size</span> <span class="o">**</span> <span class="p">(</span><span class="o">-</span><span class="mf">0.5</span><span class="p">)</span> <span class="o">*</span> <span class="nb">min</span><span class="p">(</span><span class="n">step</span> <span class="o">**</span> <span class="p">(</span><span class="o">-</span><span class="mf">0.5</span><span class="p">),</span> <span class="n">step</span> <span class="o">*</span> <span class="n">warmup</span> <span class="o">**</span> <span class="p">(</span><span class="o">-</span><span class="mf">1.5</span><span class="p">))</span>
    <span class="p">)</span>
</pre></div>
</div>
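<p>PyTorch also provides ready-made implementations of many optimizers, Adam included (<code class="docutils literal notranslate"><span class="pre">torch.optim.Adam</span></code>), which can be used directly in practice.</p>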
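<p>One way to attach the <code>rate</code> schedule to PyTorch's Adam is through <code>torch.optim.lr_scheduler.LambdaLR</code>, which multiplies the optimizer's base learning rate by the value returned from a user-supplied function. The sketch below is one possible wiring, not necessarily the author's exact setup: the placeholder model and <code>warmup=400</code> are assumptions chosen for illustration, while <code>betas=(0.9, 0.98)</code> and <code>eps=1e-9</code> are the Adam settings reported in the paper.</p>

```python
import torch
from torch.optim.lr_scheduler import LambdaLR


def rate(step, model_size, factor, warmup):
    # Same schedule as above; step 0 is mapped to 1 to avoid 0 ** -0.5.
    if step == 0:
        step = 1
    return factor * (
        model_size ** (-0.5) * min(step ** (-0.5), step * warmup ** (-1.5))
    )


model = torch.nn.Linear(4, 4)  # placeholder model, for illustration only

# The base lr is 1.0 so the effective rate is exactly what rate() returns.
optimizer = torch.optim.Adam(
    model.parameters(), lr=1.0, betas=(0.9, 0.98), eps=1e-9
)
scheduler = LambdaLR(
    optimizer,
    lr_lambda=lambda step: rate(step, model_size=512, factor=1.0, warmup=400),
)

# Record the learning rate used at each step of a dummy training run.
lrs = []
for _ in range(1000):
    optimizer.step()  # no gradients here; a real loop calls backward() first
    lrs.append(optimizer.param_groups[0]["lr"])
    scheduler.step()

# The schedule rises during warmup and decays afterwards.
peak_step = max(range(len(lrs)), key=lrs.__getitem__)
print(peak_step, lrs[peak_step])
```

<p>The two arguments of <code>min</code> are equal when the step number equals <code>warmup</code>, so that is where the learning rate peaks before the inverse-square-root decay takes over.</p>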
</section>
<section id="id9">
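<h3>Regularization<a class="headerlink" href="#id9" title="Permalink to this heading">#</a></h3>
<p>During training, the author uses label smoothing as a regularizer, which improves the accuracy of the model's predictions. Concretely, the loss is the KL divergence between the model's predicted distribution and the smoothed target distribution. The theory behind label smoothing and KL divergence is not repeated here; interested readers can refer to the following posts: <a class="reference external" href="https://blog.csdn.net/Roaddd/article/details/114761023">【正则化】Label Smoothing详解</a> and <a class="reference external" href="https://zhuanlan.zhihu.com/p/100676922">Kullback-Leibler(KL)散度介绍</a>. Label smoothing is implemented as follows:</p>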
<div class="highlight-python notranslate"><div class="highlight"><pre><span></span><span class="k">class</span> <span class="nc">LabelSmoothing</span><span class="p">(</span><span class="n">nn</span><span class="o">.</span><span class="n">Module</span><span class="p">):</span>
    <span class="s2">&quot;Implement label smoothing.&quot;</span>

    <span class="k">def</span> <span class="fm">__init__</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">size</span><span class="p">,</span> <span class="n">padding_idx</span><span class="p">,</span> <span class="n">smoothing</span><span class="o">=</span><span class="mf">0.0</span><span class="p">):</span>
        <span class="nb">super</span><span class="p">(</span><span class="n">LabelSmoothing</span><span class="p">,</span> <span class="bp">self</span><span class="p">)</span><span class="o">.</span><span class="fm">__init__</span><span class="p">()</span>
        <span class="bp">self</span><span class="o">.</span><span class="n">criterion</span> <span class="o">=</span> <span class="n">nn</span><span class="o">.</span><span class="n">KLDivLoss</span><span class="p">(</span><span class="n">reduction</span><span class="o">=</span><span class="s2">&quot;sum&quot;</span><span class="p">)</span>  <span class="c1"># KL divergence loss</span>
        <span class="bp">self</span><span class="o">.</span><span class="n">padding_idx</span> <span class="o">=</span> <span class="n">padding_idx</span>  <span class="c1"># index of the padding token</span>
        <span class="bp">self</span><span class="o">.</span><span class="n">confidence</span> <span class="o">=</span> <span class="mf">1.0</span> <span class="o">-</span> <span class="n">smoothing</span>
        <span class="bp">self</span><span class="o">.</span><span class="n">smoothing</span> <span class="o">=</span> <span class="n">smoothing</span>
        <span class="bp">self</span><span class="o">.</span><span class="n">size</span> <span class="o">=</span> <span class="n">size</span>
        <span class="bp">self</span><span class="o">.</span><span class="n">true_dist</span> <span class="o">=</span> <span class="kc">None</span>

    <span class="k">def</span> <span class="nf">forward</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">x</span><span class="p">,</span> <span class="n">target</span><span class="p">):</span>
        <span class="k">assert</span> <span class="n">x</span><span class="o">.</span><span class="n">size</span><span class="p">(</span><span class="mi">1</span><span class="p">)</span> <span class="o">==</span> <span class="bp">self</span><span class="o">.</span><span class="n">size</span>
        <span class="n">true_dist</span> <span class="o">=</span> <span class="n">x</span><span class="o">.</span><span class="n">data</span><span class="o">.</span><span class="n">clone</span><span class="p">()</span>
        <span class="n">true_dist</span><span class="o">.</span><span class="n">fill_</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">smoothing</span> <span class="o">/</span> <span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">size</span> <span class="o">-</span> <span class="mi">2</span><span class="p">))</span>
        <span class="n">true_dist</span><span class="o">.</span><span class="n">scatter_</span><span class="p">(</span><span class="mi">1</span><span class="p">,</span> <span class="n">target</span><span class="o">.</span><span class="n">data</span><span class="o">.</span><span class="n">unsqueeze</span><span class="p">(</span><span class="mi">1</span><span class="p">),</span> <span class="bp">self</span><span class="o">.</span><span class="n">confidence</span><span class="p">)</span>
        <span class="n">true_dist</span><span class="p">[:,</span> <span class="bp">self</span><span class="o">.</span><span class="n">padding_idx</span><span class="p">]</span> <span class="o">=</span> <span class="mi">0</span>
        <span class="n">mask</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">nonzero</span><span class="p">(</span><span class="n">target</span><span class="o">.</span><span class="n">data</span> <span class="o">==</span> <span class="bp">self</span><span class="o">.</span><span class="n">padding_idx</span><span class="p">)</span>
        <span class="k">if</span> <span class="n">mask</span><span class="o">.</span><span class="n">dim</span><span class="p">()</span> <span class="o">&gt;</span> <span class="mi">0</span><span class="p">:</span>
            <span class="n">true_dist</span><span class="o">.</span><span class="n">index_fill_</span><span class="p">(</span><span class="mi">0</span><span class="p">,</span> <span class="n">mask</span><span class="o">.</span><span class="n">squeeze</span><span class="p">(),</span> <span class="mf">0.0</span><span class="p">)</span>
        <span class="bp">self</span><span class="o">.</span><span class="n">true_dist</span> <span class="o">=</span> <span class="n">true_dist</span>
        <span class="k">return</span> <span class="bp">self</span><span class="o">.</span><span class="n">criterion</span><span class="p">(</span><span class="n">x</span><span class="p">,</span> <span class="n">true_dist</span><span class="o">.</span><span class="n">clone</span><span class="p">()</span><span class="o">.</span><span class="n">detach</span><span class="p">())</span>
</pre></div>
</div>
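<p>To see what the smoothed target distribution looks like, the sketch below rebuilds it in a standalone helper. The name <code>smoothed_targets</code> and the toy values (vocabulary size 5, padding index 0, smoothing 0.4) are assumptions made for illustration; the helper follows the same steps as <code>LabelSmoothing.forward</code> but zeroes padded rows with a boolean mask instead of <code>index_fill_</code>.</p>

```python
import torch
import torch.nn as nn


def smoothed_targets(target, size, padding_idx, smoothing):
    """Build the smoothed target distribution for a batch of class indices."""
    # Spread the smoothing mass over every class except the true label and
    # the padding index, hence the division by (size - 2).
    dist = torch.full((target.size(0), size), smoothing / (size - 2))
    # Put the remaining confidence on the true label.
    dist.scatter_(1, target.unsqueeze(1), 1.0 - smoothing)
    # The padding class never receives probability mass.
    dist[:, padding_idx] = 0.0
    # Rows whose target is padding contribute nothing to the loss.
    dist[target == padding_idx] = 0.0
    return dist


target = torch.tensor([2, 1, 0])  # the last position is padding
dist = smoothed_targets(target, size=5, padding_idx=0, smoothing=0.4)
print(dist)

# nn.KLDivLoss expects log-probabilities as input and probabilities as target.
torch.manual_seed(0)
log_probs = torch.log_softmax(torch.randn(3, 5), dim=-1)
loss = nn.KLDivLoss(reduction="sum")(log_probs, dist)
print(loss.item())
```

<p>Each non-padding row sums to 1 (0.6 on the true label plus 0.4 split across the three remaining non-padding classes), while the padding row is all zeros and therefore adds nothing to the summed loss.</p>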
</section>
</section>
<section id="id10">
<h2>Summary<a class="headerlink" href="#id10" title="Permalink to this heading">#</a></h2>
<section id="id11">
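<h3>Strengths<a class="headerlink" href="#id11" title="Permalink to this heading">#</a></h3>
<p>As a milestone model in the history of NLP, the Transformer is both highly innovative and highly influential. It boldly abandoned the CNN and RNN architectures that had been in use for decades and built the network entirely on the attention mechanism, with good results, which helped push attention to center stage. As for the model itself, attention lets it capture long-range correlations effectively, addressing the long-distance dependency problem that had long troubled NLP. In addition, dropping the recurrent structure of RNNs allows computation to be fully parallelized, which improves computational efficiency.</p>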
</section>
<section id="id12">
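<h3>Weaknesses<a class="headerlink" href="#id12" title="Permalink to this heading">#</a></h3>
<p>Admittedly, the Transformer also has a number of shortcomings. Most obviously, the paper that proposed it is titled "Attention Is All You Need", but is attention really all we need? Bluntly discarding CNNs and RNNs is an impressive feat, yet it also costs the model its ability to capture local features; a combination of RNN + CNN + Attention might yield better results. Second, the attention mechanism by itself carries no position information. The Transformer adds Positional Encoding to compensate, but this is only a stopgap that does not change the underlying structural limitation. Finally, the attention mechanism has a large number of parameters, which raises the barrier to entry and the cost of training. It is therefore no coincidence that, after the Transformer was proposed, powerful but expensive pre-trained models built on its structure, such as BERT and XLNet, gradually took the stage.</p>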
</section>
</section>
<section id="id13">
<h2>References<a class="headerlink" href="#id13" title="Permalink to this heading">#</a></h2>
<ol class="simple">
<li><p><a class="reference external" href="https://arxiv.org/abs/1706.03762">Attention Is All You Need</a></p></li>
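<li><p><a class="reference external" href="https://lilianweng.github.io/posts/2018-06-24-attention/">https://lilianweng.github.io/posts/2018-06-24-attention/</a></p></li>
<li><p><a class="reference external" href="https://zhuanlan.zhihu.com/p/48508221">https://zhuanlan.zhihu.com/p/48508221</a></p></li>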
</ol>
</section>
</section>


              </div>
              
            </main>
            <footer class="footer-article noprint">
                
            </footer>
        </div>
    </div>
    <div class="footer-content row">
        <footer class="col footer"><p>
  
    By ZhikangNiu<br/>
  
      &copy; Copyright 2022, ZhikangNiu.<br/>
</p>
        </footer>
    </div>
    
</div>


      </div>
    </div>
  
  <!-- Scripts loaded after <body> so the DOM is not blocked -->
  <script src="../_static/scripts/pydata-sphinx-theme.js?digest=1999514e3f237ded88cf"></script>


  </body>
</html>